Branch 154885009 (#9604)

* Enable grappler to propagate shapes through queues.
Change: 154789133

* Add whitelist support in uid of RunConfig.
Change: 154794859

* Fix a bunch of bad links and missing docs in contrib.
Change: 154820641

* Don't try to refine the shapes for a node if its inference context wasn't
successfully built by the AddNode() method.
Change: 154838211
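
A minimal sketch of the guard this describes (hypothetical wrapper; member and method details are placeholders, though the real ShapeRefiner does keep its contexts in an internal map):

```c++
// Sketch only: skip shape refinement when AddNode() never built a context.
Status ShapeRefiner::RefineNode(const Node* node) {  // hypothetical wrapper
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    // AddNode() failed for this node, so there is no InferenceContext to
    // refine; bail out instead of dereferencing a missing context.
    return errors::Internal("No inference context for node ", node->name());
  }
  return RunShapeFn(node, it->second.get());
}
```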

* Fix issue related to empty bazel.rc file.
Change: 154840138

* Remove overly precise CHECK when rendering debug output for a function.

An `_Arg` node can have more than three attrs, because the runtime may
(and does) add system-defined attrs (viz. "_output_shapes") that do
not change the meaning of the op.
Change: 154850526
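
A minimal sketch of the relaxation (illustrative only; the names and exact count are placeholders, not the actual TensorFlow code):

```c++
#include <cassert>

// Illustrative only: validate with a lower bound rather than equality, since
// the runtime may append system-defined attrs such as "_output_shapes".
void ValidateArgNodeAttrCount(int num_attrs) {
  // Before: assert(num_attrs == 3);  // overly precise; broke on extra attrs
  assert(num_attrs >= 3);  // extra system attrs don't change the op's meaning
}
```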

* Port makefile build breakage
Change: 154855106

* [TF:XLA] Try to incorporate Tensorflow node structure for large HLO GraphDefs.

This change assumes that a TF subgraph/op does not cross the boundary of an HLO
computation and always puts top-level TF subgraphs/ops under HLO computations.
Change: 154855884

Added a unit test to check what happens when two shapes with known rank but
unknown dimensions are merged.
Change: 154856675

* [XLA] Refactor constant folding operations into a dedicated module

Refactor constant folding operations into a dedicated module, and add a new
ReplaceInstruction() API to collapse the common
{ computation->ReplaceInstruction(); changed = true; } pattern.
Change: 154857025
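
The new helper, as it appears in the algebraic simplifier diff further down, is:

```c++
// Replaces old_instruction with new_instruction in the computation, marks
// the visitor as having changed the graph, and forwards any error Status.
Status ReplaceInstruction(HloInstruction* old_instruction,
                          HloInstruction* new_instruction) {
  TF_RETURN_IF_ERROR(
      computation_->ReplaceInstruction(old_instruction, new_instruction));
  changed_ = true;
  return Status::OK();
}
```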

* Java: Docs: Update instructions for Windows.

Inspired by
http://stackoverflow.com/questions/43741775/tensorflow-in-java-running-failed
Change: 154859066

* Add more documentation for features and labels.
Change: 154859649

* Added link to high-performance models
Change: 154860213

* Navigation and index for new performance section documents.
Change: 154862215

* Fix shape mismatch between loss and weights.
Change: 154862650

* Add examples to TensorShape documentation and run the autoformatter.
Change: 154862667

* Move linking of cudnn_plugin, cublas_plugin and cufft_plugin from
stream_executor to the ops that need them.
Change: 154863520

* Properly track the persistent memory usage of lookup tables.
Change: 154866686

* Reset the inputs to ShapeRefiner::RunShapeFn so that it behaves the same every time it's called.
To properly handle queues that have been populated by several enqueue ops, merge the shapes of the inputs to all the enqueue ops before calling InferenceContext::set_output_handle_shape(). This ensures that we detect incorrect queue setups (where two enqueue ops might generate tensors with incompatible shapes), and that we take all the known shape information into account instead of just that of one of the enqueue ops.
Change: 154866747
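
The merge step can be pictured with a small sketch (hypothetical helper name and signature; the real logic lives in grappler's shape inference):

```c++
// A sketch of the merging step: incompatible enqueue shapes make
// InferenceContext::Merge return an error, flagging a bad queue setup.
Status MergeEnqueueShapes(
    shape_inference::InferenceContext* ctx,
    const std::vector<shape_inference::ShapeHandle>& enqueued_shapes,
    shape_inference::ShapeHandle* merged) {
  *merged = enqueued_shapes[0];
  for (size_t i = 1; i < enqueued_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(ctx->Merge(*merged, enqueued_shapes[i], merged));
  }
  return Status::OK();
}
```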

* Make sure session_manager produces an error message when a non-tensor object is passed in;
otherwise the 'name' property is missing.
Change: 154868022

* Don't needlessly synchronize the CUDA stream in CropAndResize.
Make the op Async so we don't block an executor thread while waiting for the result of the box bounds check to be copied back to the host.
Change: 154868460
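
A sketch of the AsyncOpKernel idiom this refers to (not the actual CropAndResize kernel; the launch helper is hypothetical):

```c++
// ComputeAsync returns immediately and invokes `done` once the
// device-to-host copy of the bounds-check result has completed, instead of
// blocking the executor thread on a CUDA stream synchronization.
class CropAndResizeOpSketch : public AsyncOpKernel {
 public:
  explicit CropAndResizeOpSketch(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // Enqueue the box-bounds check and the copy back to the host, then
    // schedule the remaining work as a callback.
    LaunchBoundsCheckThenCrop(context, std::move(done));  // hypothetical
  }
};
```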

* Add contribution guidelines and standards section to CONTRIBUTING.md

Several parts are largely based on the post by @yaroslavvb at #7443 (issuecomment-279182613).

Fixes #7443
Change: 154876045

* Final draft
Change: 154876563

* Final draft
Change: 154876646

* Fix losses documentation.

Fix the documentation of get_total_loss() to be correct,
and add a helpful comment about a common pitfall.
Change: 154876822

* [XLA] Second change for HLO interpreter.

Extends HloEvaluator to allow evaluation of an HLO Computation or a single HLO
instruction with non-constant operands, by traversing the instructions in post
order and keeping track of each instruction's evaluated literal along the way.
Change: 154877580
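
Based on the evaluator code in the diff further down, usage looks roughly like this (a sketch; the computation and argument literals are assumed to already exist):

```c++
// Evaluate a computation whose parameters are bound to input literals.
HloEvaluator evaluator;
std::vector<const Literal*> args = {arg0.get(), arg1.get()};
StatusOr<std::unique_ptr<Literal>> result =
    evaluator.Evaluate(computation, args);
```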

* [tf distributions] Move the remaining whitelisted distributions to core.
Change: 154878206

* Add shape to error message.
Change: 154880260

* Revert "Fix build issue when `/usr/bin/python` path is not available (#9547)"

This reverts commit 95f37ebf0b.

.gitignore

@@ -5,7 +5,6 @@ node_modules
 /.tf_configure.bazelrc
 /bazel-*
 /third_party/py/numpy/numpy_include
-/tools/bazel.rc
 /tools/python_bin_path.sh
 /tools/git/gen
 /util/python/python_include

CONTRIBUTING.md

@@ -27,3 +27,140 @@
contributions, often because we probably won't get to them right now. If you
decide to start on an issue, leave a comment so that other people know that
you're working on it. If you want to help out, but not alone, use the issue
comment thread to coordinate.

### Contribution guidelines and standards
Before sending your pull request for
[review](https://github.com/tensorflow/tensorflow/pulls),
make sure your changes are consistent with the guidelines and follow the
TensorFlow coding style.
#### General guidelines and philosophy for contribution
* Include unit tests when you contribute new features, as they help to
a) prove that your code works correctly, and b) guard against future breaking
changes to lower the maintenance cost.
* Bug fixes also generally require unit tests, because the presence of bugs
usually indicates insufficient test coverage.
* Keep API compatibility in mind when you change code in core TensorFlow,
e.g., code in [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) and [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
TensorFlow has reached version 1 and hence cannot make
non-backward-compatible API changes without a major release. Reviewers of your
pull request will comment on any API compatibility issues.
* When you contribute a new feature to TensorFlow, the maintenance burden is (by
default) transferred to the TensorFlow team. This means that the benefit of the
contribution must be compared against the cost of maintaining the feature.
* Full new features (e.g., a new op implementing a cutting-edge algorithm)
typically will live in
[tensorflow/contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib)
to get some airtime before a decision is made regarding whether they are to be
migrated to the core.
#### License
Include a license at the top of new files.
* [C/C++ license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op.cc#L1)
* [Python license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn.py#L1)
* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1)
* [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1)
* [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2)
* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2)
* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1)
Bazel BUILD files also need to include a license section, e.g.,
[BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61).
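For reference, the C/C++ license header is the standard Apache 2.0 block (the same text appears verbatim at the top of the new test file further down in this diff):
```c++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
```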
#### C++ coding style
Changes to TensorFlow C++ code should conform to the
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
Use `clang-format` to check your C/C++ changes. To install `clang-format` on Ubuntu 16.04, do:
```bash
apt-get install -y clang-format
```
You can check a C/C++ file by doing:
```bash
clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc
diff <my_cc_file> /tmp/my_cc_file.cc
```
#### Python coding style
Changes to TensorFlow Python code should conform to the
[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
Use `pylint` to check your Python changes. To install `pylint` and
retrieve TensorFlow's custom style definition:
```bash
pip install pylint
wget -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc
```
To check a file with `pylint`:
```bash
pylint --rcfile=/tmp/pylintrc myfile.py
```
#### Coding style for other languages
* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
#### Running sanity check
If you have Docker installed on your system, you can perform a sanity check on
your changes by running the command:
```bash
tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh
```
This will catch most license, Python coding style and BUILD file issues that
may exist in your changes.
#### Running unit tests
There are two ways to run TensorFlow unit tests.
1. Using tools and libraries installed directly on your system.
Refer to the
[CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and
[GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu)
for the required packages. Alternatively, you can use the
[Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g.,
`tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu`,
for development, to avoid installing the packages directly on your system.

Once you have the packages installed, you can run a specific unit test with
bazel as follows. If the tests are to be run on GPU, add the CUDA paths to
`LD_LIBRARY_PATH` and add the `cuda` option flag:
```bash
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
export flags="--config=opt --config=cuda -k"
```
For example, to run all tests under tensorflow/python, do:
```bash
bazel test ${flags} //tensorflow/python/...
```
2. Using Docker and TensorFlow's CI scripts.
See
[TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details.

configure

@@ -353,9 +353,8 @@ if [[ "$TF_NEED_VERBS" == "1" ]]; then
 fi

 # Append CC optimization flags to bazel.rc
-echo >> tools/bazel.rc
 for opt in $CC_OPT_FLAGS; do
-  echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc
+  write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt'
 done

 # Run the gen_git_source to create links where bazel can track dependencies for
tensorflow/compiler/xla/service/BUILD

@@ -80,6 +80,8 @@ cc_library(
         ":hlo_query",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@@ -1418,6 +1420,27 @@ cc_library(
     ],
 )

+cc_test(
+    name = "hlo_constant_folding_test",
+    srcs = ["hlo_constant_folding_test.cc"],
+    deps = [
+        ":cpu_plugin",
+        ":hlo",
+        ":hlo_constant_folding",
+        ":hlo_matchers",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 cc_library(
     name = "device_memory_allocator",
     srcs = ["device_memory_allocator.cc"],

tensorflow/compiler/xla/service/algebraic_simplifier.cc

@@ -219,12 +219,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
       HloInstruction* operand, HloInstruction* max,
       HloInstruction* max_operand);

-  // Tries to constant fold a concatenate operation, and returns true if the
-  // operation has been performed. An error status is returned in case of error.
-  StatusOr<bool> TryConcatenateConstantFold(
-      HloInstruction* concatenate,
-      tensorflow::gtl::ArraySlice<HloInstruction*> operands);
-
   // A Reshape or Broadcast that feeds an element-wise operation with a unique
   // non-scalar operand can sink to after the operation.
   StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
@@ -236,12 +230,23 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   Status ReplaceWithNewInstruction(
       HloInstruction* old_instruction,
       std::unique_ptr<HloInstruction> new_instruction) {
-    TF_CHECK_OK(computation_->ReplaceWithNewInstruction(
+    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
         old_instruction, std::move(new_instruction)));
     changed_ = true;
     return Status::OK();
   }

+  // Replaces the existing HLO instruction old_instruction, with
+  // new_instruction, and marks the optimizer status as changed.
+  // Returns the Status representing the result of the replace operation.
+  Status ReplaceInstruction(HloInstruction* old_instruction,
+                            HloInstruction* new_instruction) {
+    TF_RETURN_IF_ERROR(
+        computation_->ReplaceInstruction(old_instruction, new_instruction));
+    changed_ = true;
+    return Status::OK();
+  }
+
   // Current HloComputation instance the AlgebraicSimplifierVisitor is
   // traversing.
   HloComputation* computation_;
@@ -290,8 +295,7 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast(
   auto bitcast = computation_->AddInstruction(
       HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast,
                                   instruction->mutable_operand(0)));
-  TF_CHECK_OK(computation_->ReplaceInstruction(instruction, bitcast));
-  changed_ = true;
+  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
 }

 bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
@@ -299,9 +303,7 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
   if (!SameShape(old_instruction, new_instruction)) {
     return false;
   }
-  TF_CHECK_OK(
-      computation_->ReplaceInstruction(old_instruction, new_instruction));
-  changed_ = true;
+  TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
   return true;
 }

@@ -329,63 +331,6 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
   return Status::OK();
 }

StatusOr<bool> AlgebraicSimplifierVisitor::TryConcatenateConstantFold(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
if (operands[0]->opcode() == HloOpcode::kConstant) {
// If all the operands of a concatenate are constant, fold them into a
// single constant tensor.
// The concatenate dimension is going to be the sum of all the concatenate
// dimensions.
int64 concat_dim = concatenate->dimensions()[0];
const Shape& reference_shape = operands[0]->shape();
if (ShapeUtil::IsTuple(reference_shape)) {
VLOG(5) << "Tuples not currently supported by the concatenate constant"
" folder";
return false;
}
int64 rank = ShapeUtil::Rank(reference_shape);
std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
if (concat_dim < 0) {
concat_dim += rank;
}
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
if (operands[i]->opcode() != HloOpcode::kConstant ||
ShapeUtil::IsTuple(operand_shape)) {
return false;
}
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
std::vector<int64> source_indices(rank, 0);
std::vector<int64> dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
Status status = LiteralUtil::Copy(
operand->literal(), source_indices, literal.get(), dest_indices,
AsInt64Slice(operand_shape.dimensions()));
if (!status.ok()) {
VLOG(1) << "Error while creating concatenated literal : " << status;
return false;
}
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
TF_CHECK_OK(computation_->ReplaceWithNewInstruction(
concatenate, HloInstruction::CreateConstant(std::move(literal))));
changed_ = true;
return true;
}
return false;
}
 Status AlgebraicSimplifierVisitor::HandleConcatenate(
     HloInstruction* concatenate,
     tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
@@ -394,13 +339,6 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
     ReplaceInstructionIfSameShape(concatenate, operands[0]);
     return Status::OK();
   }
-  // If all the concatenate operands are constant, this will get folded into a
-  // new constant literal.
-  TF_ASSIGN_OR_RETURN(bool folded,
-                      TryConcatenateConstantFold(concatenate, operands));
-  if (folded) {
-    return Status::OK();
-  }
   // Filter out and remove empty operands.
   std::vector<HloInstruction*> nonempty_operands;
   for (HloInstruction* operand : operands) {
@@ -799,65 +737,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
   return Status::OK();
 }

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<HloInstruction> ConvertIfTypesMatch(
const Literal& src_literal) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
return HloInstruction::CreateConstant(
LiteralUtil::Convert<typename primitive_util::PrimitiveTypeToNative<
primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal));
}
template <PrimitiveType primitive_src_type>
static std::unique_ptr<HloInstruction> ConvertIfDestTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
static std::unique_ptr<HloInstruction> ConvertIfSrcTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (src_literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
CONVERT_IF_DEST_TYPE_MATCHES(S64)
CONVERT_IF_DEST_TYPE_MATCHES(U8)
CONVERT_IF_DEST_TYPE_MATCHES(U32)
CONVERT_IF_DEST_TYPE_MATCHES(U64)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
 // A conversion to the same element type as the operand is a nop and can be
 // removed. A conversion of a constant can be simplified by making a new
 // constant.
@@ -866,14 +745,7 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
   PrimitiveType src_type = operand->shape().element_type();
   PrimitiveType dest_type = convert->shape().element_type();
   if (src_type == dest_type) {
-    changed_ = true;
-    return computation_->ReplaceInstruction(convert, operand);
-  }
-  if (operand->opcode() == HloOpcode::kConstant) {
-    const Literal& src_literal = operand->literal();
-    std::unique_ptr<HloInstruction> new_constant =
-        ConvertIfSrcTypeMatches(src_literal, dest_type);
-    return ReplaceWithNewInstruction(convert, std::move(new_constant));
+    return ReplaceInstruction(convert, operand);
   }
   return Status::OK();
 }
@@ -1080,8 +952,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
   // Delete no-op reshapes, i.e. where shape = operand shape.
   if (SameShape(reshape, operand)) {
     VLOG(10) << "deleting no-op reshape";
-    changed_ = true;
-    return computation_->ReplaceInstruction(reshape, operand);
+    return ReplaceInstruction(reshape, operand);
   }

   // Merge reshapes.
@@ -1131,8 +1002,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse,
   };
   if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(),
                   dim_is_one)) {
-    changed_ = true;
-    return computation_->ReplaceInstruction(reverse, operand);
+    return ReplaceInstruction(reverse, operand);
   }
   return Status::OK();
 }
@@ -1143,21 +1013,6 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
   if (ReplaceInstructionIfSameShape(slice, operand)) {
     return Status::OK();
   }
if (operand->opcode() == HloOpcode::kConstant) {
const Shape& shape = slice->shape();
auto literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
Status status = LiteralUtil::Copy(operand->literal(), slice->slice_starts(),
literal.get(), dest_indices,
AsInt64Slice(shape.dimensions()));
if (status.ok()) {
TF_CHECK_OK(ReplaceWithNewInstruction(
slice, HloInstruction::CreateConstant(std::move(literal))));
} else {
VLOG(1) << "Error while creating sliced literal : " << status;
}
}
   return Status::OK();
 }

@@ -1247,8 +1102,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
   if (std::is_sorted(transpose->dimensions().begin(),
                      transpose->dimensions().end())) {
     VLOG(10) << "deleting no-op transpose";
-    changed_ = true;
-    return computation_->ReplaceInstruction(transpose, operand);
+    return ReplaceInstruction(transpose, operand);
   }

   if (HloOpcode::kTranspose == operand->opcode()) {
@@ -1379,9 +1233,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
   auto new_rhs = add_bitcast(new_filter_shape, rhs);
   auto dot = computation_->AddInstruction(HloInstruction::CreateBinary(
       dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs));
-  changed_ = true;
-  return computation_->ReplaceInstruction(convolution,
-                                          add_bitcast(convolution_shape, dot));
+  return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
 }

 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(

tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

@@ -466,75 +466,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
   EXPECT_THAT(computation->root_instruction(), input);
 }

TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
computation->root_instruction()->literal()),
42);
}
TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
computation->root_instruction()->literal()),
42.0f);
}
TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(
LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}),
42);
EXPECT_EQ(
LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}),
19);
}
 // Test that copies are removed.
 TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -1666,69 +1597,5 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 }

TEST_F(AlgebraicSimplifierTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
tensorflow::gtl::ArraySlice<int64> dimensions;
tensorflow::gtl::ArraySlice<int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
};
for (auto& test_config : test_configs) {
HloComputation::Builder builder(TestName());
std::vector<int64> dimensions(test_config.dimensions.begin(),
test_config.dimensions.end());
int64 concat_size = 0;
std::vector<HloInstruction*> operands;
for (auto csize : test_config.concat_sizes) {
dimensions[test_config.concat_dimension] = csize;
concat_size += csize;
auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
HloInstruction* insn = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
operands.push_back(insn);
}
dimensions[test_config.concat_dimension] = concat_size;
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
builder.AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, test_config.concat_dimension));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
}
TEST_F(AlgebraicSimplifierTest, Slice) {
HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9};
auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
HloInstruction* lit_insn = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
builder.AddInstruction(
HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
 }  // namespace
 }  // namespace xla

tensorflow/compiler/xla/service/gpu/BUILD

@@ -264,6 +264,8 @@ cc_library(
         "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/platform/default/build_config:cublas_plugin",
+        "//tensorflow/core/platform/default/build_config:cudnn_plugin",
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
     ],
 )

tensorflow/compiler/xla/service/hlo_constant_folding.cc

@@ -15,16 +15,14 @@ limitations under the License.

 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"

-#include <list>
-#include <map>
 #include <memory>
-#include <set>
 #include <string>
 #include <utility>
 #include <vector>

 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -34,52 +32,222 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"

 namespace xla {

namespace {
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<Literal> ConvertIfTypesMatch(
const Literal& src_literal) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
return LiteralUtil::Convert<
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal);
}
template <PrimitiveType primitive_src_type>
static std::unique_ptr<Literal> ConvertIfDestTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
static std::unique_ptr<Literal> ConvertIfSrcTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (src_literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
CONVERT_IF_DEST_TYPE_MATCHES(S64)
CONVERT_IF_DEST_TYPE_MATCHES(U8)
CONVERT_IF_DEST_TYPE_MATCHES(U32)
CONVERT_IF_DEST_TYPE_MATCHES(U64)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
} // namespace
// ConstantFolderVisitor traverses the HLO computation and reduces certain
// constant graph sections, to literals.
class ConstantFolderVisitor : public DfsHloVisitorWithDefault {
public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleConvert(HloInstruction* convert,
HloInstruction* operand) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleTranspose(HloInstruction* transpose) override;
// Returns whether a constant folding operation has occurred.
const bool changed() const { return changed_; }
// Runs the visitor on a computation and returns whether any changes were
// performed.
static StatusOr<bool> Run(HloComputation* computation);
private:
ConstantFolderVisitor() = default;
// Replaces the existing HLO instruction old_instruction, with a literal,
// and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithConstant(HloInstruction* old_instruction,
std::unique_ptr<Literal> literal) {
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
old_instruction, HloInstruction::CreateConstant(std::move(literal))));
changed_ = true;
return Status::OK();
}
// Whether any constant folding operations have occurred.
bool changed_ = false;
};
StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
ConstantFolderVisitor visitor;
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
 StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
-  XLA_VLOG_LINES(2,
-                 "HloConstantFolding::Run(), before:\n" + module->ToString());
   bool changed = false;
-  for (auto& computation : module->computations()) {
-    for (auto instruction : computation->MakeInstructionPostOrder()) {
-      // Skip dead code.
+  for (auto& comp : module->computations()) {
+    TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
+    changed = changed || result;
if (instruction->user_count() == 0 &&
computation->root_instruction() != instruction) {
continue;
}
// Depending on the opcode, choose how to handle constant operands.
//
// TODO(b/35975797): Fold constant computations for more than reshapes and
// transposes.
switch (instruction->opcode()) {
case HloOpcode::kReshape: {
if (instruction->operand(0)->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(
auto reshaped_literal,
LiteralUtil::Reshape(
instruction->operand(0)->literal(),
AsInt64Slice(instruction->shape().dimensions())));
TF_CHECK_OK(computation->ReplaceWithNewInstruction(
instruction,
HloInstruction::CreateConstant(std::move(reshaped_literal))));
changed = true;
}
break;
}
case HloOpcode::kTranspose: {
if (instruction->operand(0)->opcode() == HloOpcode::kConstant) {
auto transposed_literal = LiteralUtil::Transpose(
instruction->operand(0)->literal(), instruction->dimensions());
TF_CHECK_OK(computation->ReplaceWithNewInstruction(
instruction,
HloInstruction::CreateConstant(std::move(transposed_literal))));
changed = true;
}
break;
}
default:
break;
}
}
   }
-  XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
   return changed;
 }

Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(
auto reshaped_literal,
LiteralUtil::Reshape(reshape->operand(0)->literal(),
AsInt64Slice(reshape->shape().dimensions())));
return ReplaceWithConstant(reshape, std::move(reshaped_literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
auto transposed_literal = LiteralUtil::Transpose(
transpose->operand(0)->literal(), transpose->dimensions());
return ReplaceWithConstant(transpose, std::move(transposed_literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
if (operands[0]->opcode() == HloOpcode::kConstant) {
// If all the operands of a concatenate are constant, fold them into a
// single constant tensor.
// The result concatenate dimension is going to be the sum of all the
// concatenate dimensions of the arrays taking part of the operation.
int64 concat_dim = concatenate->dimensions()[0];
const Shape& reference_shape = operands[0]->shape();
CHECK(!ShapeUtil::IsTuple(reference_shape));
int64 rank = ShapeUtil::Rank(reference_shape);
std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
if (concat_dim < 0) {
concat_dim += rank;
}
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
CHECK(!ShapeUtil::IsTuple(operand_shape));
if (operands[i]->opcode() != HloOpcode::kConstant) {
return Status::OK();
}
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
std::vector<int64> source_indices(rank, 0);
std::vector<int64> dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
operand->literal(), source_indices, literal.get(), dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
return ReplaceWithConstant(concatenate, std::move(literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kConstant) {
const Shape& shape = slice->shape();
auto literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
operand->literal(), slice->slice_starts(), literal.get(), dest_indices,
AsInt64Slice(shape.dimensions())));
TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal)));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kConstant) {
const Literal& src_literal = operand->literal();
std::unique_ptr<Literal> new_constant =
ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type());
return ReplaceWithConstant(convert, std::move(new_constant));
}
return Status::OK();
}
 }  // namespace xla

tensorflow/compiler/xla/service/hlo_constant_folding.h

@@ -25,12 +25,10 @@ namespace xla {
 // computation on constants.
 class HloConstantFolding : public HloPassInterface {
  public:
-  explicit HloConstantFolding() {}
-  ~HloConstantFolding() override {}
   tensorflow::StringPiece name() const override { return "constant_folding"; }

-  // Run ConstantFolding on the given module. Returns whether the module was
-  // changed (common subexpressions were found and eliminated).
+  // Run constant folding operations on the given module. Returns whether the
+  // module was changed (constant expressions folded).
   StatusOr<bool> Run(HloModule* module) override;
 };

tensorflow/compiler/xla/service/hlo_constant_folding_test.cc (new file)

@@ -0,0 +1,169 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include <memory>
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
using HloConstantFoldingTest = HloTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding simplifier;
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
computation->root_instruction()->literal()),
42);
}
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding simplifier;
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
computation->root_instruction()->literal()),
42.0f);
}
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding simplifier;
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(
LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}),
42);
EXPECT_EQ(
LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}),
19);
}
TEST_F(HloConstantFoldingTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
tensorflow::gtl::ArraySlice<int64> dimensions;
tensorflow::gtl::ArraySlice<int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
};
for (auto& test_config : test_configs) {
HloComputation::Builder builder(TestName());
std::vector<int64> dimensions(test_config.dimensions.begin(),
test_config.dimensions.end());
int64 concat_size = 0;
std::vector<HloInstruction*> operands;
for (auto csize : test_config.concat_sizes) {
dimensions[test_config.concat_dimension] = csize;
concat_size += csize;
auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
HloInstruction* insn = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
operands.push_back(insn);
}
dimensions[test_config.concat_dimension] = concat_size;
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
builder.AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, test_config.concat_dimension));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
HloConstantFolding simplifier;
ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
}
TEST_F(HloConstantFoldingTest, Slice) {
HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9};
auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
HloInstruction* lit_insn = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
builder.AddInstruction(
HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
HloConstantFolding simplifier;
ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
} // namespace
} // namespace xla

tensorflow/compiler/xla/service/hlo_evaluator.cc

@@ -26,20 +26,21 @@ limitations under the License.
 #include "tensorflow/compiler/xla/index_util.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_query.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
@@ -53,9 +54,7 @@ std::unique_ptr<Literal> ElementWiseUnaryOp(
     const Literal& operand) {
   DCHECK(ShapeUtil::SameDimensions(shape, operand.shape()));

-  auto result = MakeUnique<Literal>();
-  *result->mutable_shape() = shape;
-  LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get());
+  auto result = LiteralUtil::CreateFromShape(shape);

   std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
   do {
@@ -74,9 +73,7 @@ std::unique_ptr<Literal> ElementWiseBinaryOp(
   DCHECK(ShapeUtil::SameDimensions(shape, rhs.shape()));
   DCHECK(ShapeUtil::SameDimensions(lhs.shape(), rhs.shape()));

-  auto result = MakeUnique<Literal>();
-  *result->mutable_shape() = shape;
-  LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get());
+  auto result = LiteralUtil::CreateFromShape(shape);

   std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
   do {
@@ -99,9 +96,7 @@ std::unique_ptr<Literal> ElementWiseTernaryOp(
   DCHECK(ShapeUtil::SameDimensions(lhs.shape(), rhs.shape()));
   DCHECK(ShapeUtil::SameDimensions(rhs.shape(), ehs.shape()));

-  auto result = MakeUnique<Literal>();
-  *result->mutable_shape() = shape;
-  LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get());
+  auto result = LiteralUtil::CreateFromShape(shape);

   std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
   do {
@@ -130,29 +125,130 @@ NativeT AbsoluteVal(NativeT value) {
   return std::abs(value);
 }

-template <typename NativeT>
-StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal(
-    HloInstruction* instruction) {
-  DCHECK(hlo_query::AllOperandsAreConstants(*instruction));
+}  // namespace
Status HloEvaluator::DefaultAction(HloInstruction* hlo) {
VLOG(2) << "Handle instruction: " << hlo->ToString();
Shape shape = hlo->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
TF_ASSIGN_OR_RETURN(evaluated_[hlo], EvaluateBasedOnType(hlo));
return Status::OK();
}
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
VLOG(2) << "Parameter evaluated to: "
<< LiteralUtil::ToString(*input_literal);
CHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()));
evaluated_[parameter] = MakeUnique<Literal>(*input_literal);
return Status::OK();
}
Status HloEvaluator::HandleConstant(HloInstruction* constant,
const Literal& literal) {
VLOG(2) << "HandleConstant: " << constant->ToString();
CHECK(ShapeUtil::Equal(constant->shape(), literal.shape()));
evaluated_[constant] = MakeUnique<Literal>(literal);
return Status::OK();
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloComputation* computation,
tensorflow::gtl::ArraySlice<const Literal*> args) {
arg_literals_ = args;
TF_RETURN_IF_ERROR(computation->Accept(this));
return std::move(FindOrDie(evaluated_, computation->root_instruction()));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> args) {
DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
Shape shape = instruction->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
arg_literals_ = args;
// Evaluate operands of Parameter type against the input literals which caches
// the evaluated literal results.
for (const auto operand : instruction->operands()) {
if (operand->opcode() == HloOpcode::kParameter) {
TF_CHECK_OK(HandleParameter(operand));
} else if (operand->opcode() == HloOpcode::kConstant) {
evaluated_[operand] = MakeUnique<Literal>(operand->literal());
}
}
TF_RETURN_IF_ERROR(instruction->Visit(this));
return std::move(FindOrDie(evaluated_, instruction));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateBasedOnType(
HloInstruction* instruction) {
Shape shape = instruction->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
switch (shape.element_type()) {
case PRED:
return EvaluateSameTypedElementwise<bool>(instruction);
case U8:
return EvaluateSameTypedElementwise<uint8>(instruction);
case U16:
return Unimplemented("unhandled primitive type: %s.",
PrimitiveType_Name(U16).c_str());
case U32:
return EvaluateSameTypedElementwise<uint32>(instruction);
case U64:
return EvaluateSameTypedElementwise<uint64>(instruction);
case S8:
return EvaluateSameTypedElementwise<int8>(instruction);
case S16:
return Unimplemented("unhandled primitive type: %s.",
PrimitiveType_Name(S16).c_str());
case S32:
return EvaluateSameTypedElementwise<int32>(instruction);
case S64:
return EvaluateSameTypedElementwise<int64>(instruction);
case F16:
return Unimplemented("unhandled primitive type: %s.",
PrimitiveType_Name(F16).c_str());
case F32:
return EvaluateSameTypedElementwise<float>(instruction);
case F64:
return EvaluateSameTypedElementwise<double>(instruction);
default:
return Unimplemented("unhandled primitive type: %s.",
PrimitiveType_Name(shape.element_type()).c_str());
}
}
template <typename NativeT>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateSameTypedElementwise(
HloInstruction* instruction) {
  const std::vector<HloInstruction*>& operands = instruction->operands();
  HloOpcode opcode = instruction->opcode();
  const Shape& shape = instruction->shape();
  switch (opcode) {
    // TODO(b/35950897): many of the STL functions used here are not
    // overloaded for all XLA primitive types.

    // Unary element-wise ops.
    //
    case HloOpcode::kAbs:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return AbsoluteVal(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kCeil:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return std::ceil(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kConvert:
      CHECK_EQ(operands.size(), 1);
      // TODO(b/35950897): implement Convert.
@@ -162,37 +258,37 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal(
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return operand; },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kExp:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return std::exp(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kFloor:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return std::floor(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kIsFinite:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return std::isfinite(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kLog:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return std::log(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kLogicalNot:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return !operand; },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kNegate:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return -operand; },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kSign:
      CHECK_EQ(operands.size(), 1);
      CHECK(primitive_util::IsIntegralType(shape.element_type()));
@@ -201,95 +297,113 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal(
            return (NativeT(0) < operand) -
                   (operand < NativeT(0));
          },
          GetEvaluatedLiteralFor(operands[0]));
    case HloOpcode::kTanh:
      CHECK_EQ(operands.size(), 1);
      return ElementWiseUnaryOp<NativeT>(
          shape, [](NativeT operand) { return std::tanh(operand); },
          GetEvaluatedLiteralFor(operands[0]));
    // Binary element-wise ops.
    //
    case HloOpcode::kAdd:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs + rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kDivide:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs / rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kMultiply:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs * rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kSubtract:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs - rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kEq:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<bool>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs == rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kGe:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<bool>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs >= rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kGt:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<bool>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs > rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kLe:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<bool>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs <= rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kLt:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<bool>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs < rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kNe:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<bool>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs != rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kMaximum:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return std::max(lhs, rhs); },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kMinimum:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return std::min(lhs, rhs); },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kPower:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return std::pow(lhs, rhs); },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kRemainder:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape,
          [](NativeT lhs, NativeT rhs) { return std::remainder(lhs, rhs); },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kLogicalAnd:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs && rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    case HloOpcode::kLogicalOr:
      CHECK_EQ(operands.size(), 2);
      return ElementWiseBinaryOp<NativeT>(
          shape, [](NativeT lhs, NativeT rhs) { return lhs || rhs; },
          GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]));
    // Ternary element-wise ops.
    //
    case HloOpcode::kClamp: {
      CHECK_EQ(operands.size(), 3);
      std::function<NativeT(NativeT, NativeT, NativeT)> clamp_op =
@@ -297,8 +411,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal(
            return std::max(low, std::min(value, high));
          };
      return ElementWiseTernaryOp<NativeT, NativeT, NativeT, NativeT>(
          shape, std::move(clamp_op), GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]),
          GetEvaluatedLiteralFor(operands[2]));
    } break;
    case HloOpcode::kSelect: {
      CHECK_EQ(operands.size(), 3);
@@ -311,8 +426,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal(
            return on_false;
          };
      return ElementWiseTernaryOp<NativeT, bool, NativeT, NativeT>(
          shape, std::move(select_op), GetEvaluatedLiteralFor(operands[0]),
          GetEvaluatedLiteralFor(operands[1]),
          GetEvaluatedLiteralFor(operands[2]));
    } break;
    default:
      return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
@@ -320,48 +436,4 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal(
  }
}
} // namespace
/* static */ StatusOr<std::unique_ptr<Literal>>
HloEvaluator::EvaluateOpForLiteral(HloInstruction* instruction) {
DCHECK(hlo_query::AllOperandsAreConstants(*instruction));
Shape shape = instruction->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
// REVIEW QUESTION: other than a few operations, do we need to handle the
// general case of operands being of different types in the context of the
// evaluator?
switch (shape.element_type()) {
case PRED:
return EvaluateOpForLiteralInternal<bool>(instruction);
case U8:
return EvaluateOpForLiteralInternal<uint8>(instruction);
case U16:
LOG(FATAL) << "U16/uint16 is unimplemented.";
case U32:
return EvaluateOpForLiteralInternal<uint32>(instruction);
case U64:
return EvaluateOpForLiteralInternal<uint64>(instruction);
case S8:
return EvaluateOpForLiteralInternal<int8>(instruction);
case S16:
LOG(FATAL) << "S16/int16 is unimplemented.";
case S32:
return EvaluateOpForLiteralInternal<int32>(instruction);
case S64:
return EvaluateOpForLiteralInternal<int64>(instruction);
case F16:
LOG(FATAL) << "F16 is unimplemented.";
case F32:
return EvaluateOpForLiteralInternal<float>(instruction);
case F64:
return EvaluateOpForLiteralInternal<double>(instruction);
default:
return Unimplemented("unhandled primitive type: %s.",
PrimitiveType_Name(shape.element_type()).c_str());
}
}
} // namespace xla
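
The interpreter's core pattern above is runtime type dispatch: EvaluateBasedOnType switches on the shape's PrimitiveType and instantiates the templated EvaluateSameTypedElementwise for the matching native type. Below is a minimal standalone sketch of that same pattern; the enum, Buffer struct, and function names are hypothetical stand-ins for illustration, not the XLA API:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for XLA's PrimitiveType and Literal; illustration
// only, not the real API.
enum class PrimType { S32, F32 };

struct Buffer {
  PrimType type;
  std::vector<int32_t> s32;  // used when type == PrimType::S32
  std::vector<float> f32;    // used when type == PrimType::F32
};

// Templated elementwise kernel, analogous to one case of
// EvaluateSameTypedElementwise (here: HloOpcode::kNegate).
template <typename NativeT>
std::vector<NativeT> Negate(const std::vector<NativeT>& in) {
  std::vector<NativeT> out;
  out.reserve(in.size());
  for (NativeT v : in) out.push_back(-v);
  return out;
}

// Runtime dispatch on the element type, analogous to EvaluateBasedOnType.
Buffer EvalNegate(const Buffer& in) {
  Buffer out{in.type, {}, {}};
  switch (in.type) {
    case PrimType::S32:
      out.s32 = Negate<int32_t>(in.s32);
      break;
    case PrimType::F32:
      out.f32 = Negate<float>(in.f32);
      break;
  }
  return out;
}

int main() {
  Buffer b{PrimType::S32, {1, -2, 3}, {}};
  for (int32_t v : EvalNegate(b).s32) std::cout << v << " ";  // prints: -1 2 -3
}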


@@ -18,22 +18,89 @@ limitations under the License.
#include <memory>

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// Responsible for evaluating HLO and obtaining literals as the evaluation
// results.
//
// This class is not thread-safe.
class HloEvaluator : public DfsHloVisitorWithDefault {
 public:
  HloEvaluator() {}
  ~HloEvaluator() override {}

  // Evaluates a HLO computation against an array of pointers to literals.
  // Returns the evaluated result as a literal if successful.
  // Precondition: argument literals are in post-order corresponding to the
  // input instruction's parameters.
  StatusOr<std::unique_ptr<Literal>> Evaluate(
      HloComputation* computation,
      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);

  // Evaluates a single HLO instruction against an array of pointers to
  // literals. Returns the evaluated result as a literal if successful.
  // Preconditions:
  // 1. argument literals are in post-order corresponding to the input
  //    instruction's parameters.
  // 2. the instruction's operands must be of either Parameter or Constant
  //    type.
  // TODO(b/35950897): implement more ops other than element-wise ops.
  // TODO(b/35950897): handle broadcasts.
  StatusOr<std::unique_ptr<Literal>> Evaluate(
      HloInstruction* instruction,
      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);

 protected:
  // The following methods implement the DfsHloVisitor interface.
  //
  // DefaultAction here handles all non-specialized instructions (i.e.,
  // instructions without a corresponding Handle* method).
  // TODO(b/35950897): it's likely better to refactor the switches here and
  // push the switch up to templated methods instead, likely at the
  // DfsHloVisitor level.
  Status DefaultAction(HloInstruction* hlo_instruction) override;

  Status HandleParameter(HloInstruction* parameter) override;

  Status HandleConstant(HloInstruction* constant,
                        const Literal& literal) override;

 private:
  // Evaluates a single HLO instruction and returns the result as a Literal
  // if successful. A Status is returned on error.
  StatusOr<std::unique_ptr<Literal>> EvaluateBasedOnType(
      HloInstruction* instruction);

  // Evaluates an element-wise HLO instruction that has the same output
  // literal type as the operands' types.
  template <typename NativeT>
  StatusOr<std::unique_ptr<Literal>> EvaluateSameTypedElementwise(
      HloInstruction* instruction);

  // Returns the already-evaluated literal result for the instruction.
  // Crashes (with a log message) if the given instruction has not been
  // evaluated previously.
  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
    auto it = evaluated_.find(hlo);
    CHECK(it != evaluated_.end())
        << "could not find evaluated value for: " << hlo->ToString();
    return *(it->second);
  }

  // Tracks each HLO instruction and its evaluated literal result.
  tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
      evaluated_;
  // Stores input literals, assuming they are in post-order. Literals are not
  // owned by this class, and they must outlive the lifetime of the instance
  // of this class.
  tensorflow::gtl::ArraySlice<const Literal*> arg_literals_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
};

} // namespace xla
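
As a quick orientation, here is how the new instruction-level entry point is exercised, distilled from the DoesAdd test later in this change (with constant operands the argument-literal list is empty); this is a sketch that assumes the XLA headers above plus shape_util.h are in scope:

// Build a kAdd instruction over two constants and evaluate it.
auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});

auto c1 = HloInstruction::CreateConstant(std::move(lhs));
auto c2 = HloInstruction::CreateConstant(std::move(rhs));
auto add =
    HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1.get(), c2.get());

HloEvaluator evaluator;
// The second argument is the post-order list of parameter literals; it is
// empty here because both operands are constants.
std::unique_ptr<Literal> result =
    evaluator.Evaluate(add.get(), {}).ConsumeValueOrDie();
// result now holds {{3, 4}, {-96, 8}}.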


@@ -14,10 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -29,9 +32,16 @@ limitations under the License.
namespace xla {
namespace {

class HloEvaluatorTest : public ::testing::Test {
 protected:
  HloEvaluatorTest() { evaluator_ = MakeUnique<HloEvaluator>(); }

  std::unique_ptr<HloEvaluator> evaluator_;
};
// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
// with 3 operands.
TEST_F(HloEvaluatorTest, DoesClamp) {
  auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
  auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
  auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
@@ -44,7 +54,7 @@ TEST(HloEvaluatorTest, DoesClamp) {
      shape, HloOpcode::kClamp, c1.get(), c2.get(), c3.get());

  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie();

  auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
@@ -53,7 +63,7 @@ TEST(HloEvaluatorTest, DoesClamp) {

// Verifies that HloEvaluator evaluates a HLO instruction that performs select
// with 3 operands.
TEST_F(HloEvaluatorTest, DoesSelect) {
  auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
  auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
  auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
@@ -66,7 +76,7 @@ TEST(HloEvaluatorTest, DoesSelect) {
      shape, HloOpcode::kSelect, c1.get(), c2.get(), c3.get());

  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie();

  auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
@@ -75,7 +85,7 @@ TEST(HloEvaluatorTest, DoesSelect) {

// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest, DoesAdd) {
  auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
  auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
@@ -86,7 +96,7 @@ TEST(HloEvaluatorTest, DoesAdd) {
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1.get(), c2.get());

  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie();

  auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
@@ -95,7 +105,7 @@ TEST(HloEvaluatorTest, DoesAdd) {

// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest, DoesDivide) {
  auto lhs_s64 = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
  auto rhs_s64 = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
@@ -106,7 +116,7 @@ TEST(HloEvaluatorTest, DoesDivide) {
                                               c1_s64.get(), c2_s64.get());

  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie();

  auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
@@ -121,8 +131,7 @@ TEST(HloEvaluatorTest, DoesDivide) {
  instruction = HloInstruction::CreateBinary(shape_f64, HloOpcode::kDivide,
                                             c1_f64.get(), c2_f64.get());

  result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie();

  expected =
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
@@ -132,21 +141,51 @@ TEST(HloEvaluatorTest, DoesDivide) {

// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbs) {
  auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
  auto c1 = HloInstruction::CreateConstant(std::move(operand));
  auto instruction =
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get());

  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie();

  auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});

  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}
// Verifies that HloEvaluator can traverse a HLO computation whose
// instructions have operands that are neither parameters nor constants.
TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
  HloComputation::Builder builder(
      ::testing::UnitTest::GetInstance()->current_test_info()->name());

  auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
  auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
  auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
  std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};

  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});

  auto param_lhs = HloInstruction::CreateParameter(0, shape, "lhs");
  auto param_rhs = HloInstruction::CreateParameter(1, shape, "rhs");
  auto lhs_instruction = HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, param_lhs.get(), param_rhs.get());

  auto param_rhs2 = HloInstruction::CreateParameter(2, shape, "rhs2");
  auto root_instruction = HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get());

  builder.AddInstruction(std::move(root_instruction));
  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie();

  auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});

  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}

} // namespace
} // namespace xla


@@ -32,6 +32,16 @@ bool IsConstantR0F32(HloInstruction* instruction, float* out) {
  return false;
}

bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
  for (const auto& operand : instruction.operands()) {
    if (operand->opcode() != HloOpcode::kParameter &&
        operand->opcode() != HloOpcode::kConstant) {
      return false;
    }
  }
  return true;
}

bool AllOperandsAreParameters(const HloInstruction& instruction) {
  for (const auto& operand : instruction.operands()) {
    if (operand->opcode() != HloOpcode::kParameter) {


@@ -28,6 +28,10 @@ namespace hlo_query {
// Precondition: out != nullptr
bool IsConstantR0F32(HloInstruction* instruction, float* out);

// Returns whether all of an instruction's operands are either constants or
// parameters.
bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction);

// Returns whether all of an instruction's operands are parameters.
bool AllOperandsAreParameters(const HloInstruction& instruction);
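
A plausible call-site shape for the new predicate is to gate the single-instruction evaluator on its documented precondition. The wrapper below is hypothetical (not part of this change) and assumes hlo_evaluator.h and hlo_query.h are included:

// Hypothetical helper, for illustration only.
StatusOr<std::unique_ptr<Literal>> TryEvaluateInstruction(
    HloInstruction* instruction,
    tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
  // Precondition 2 of HloEvaluator::Evaluate(HloInstruction*, ...).
  if (!hlo_query::AllOperandsAreParametersOrConstants(*instruction)) {
    return Unimplemented("all operands must be parameters or constants: %s.",
                         instruction->ToString().c_str());
  }
  HloEvaluator evaluator;
  return evaluator.Evaluate(instruction, arg_literals);
}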


@@ -88,12 +88,18 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
  if (ContainsKey(instruction_to_node_name_, instruction)) {
    return instruction_to_node_name_[instruction];
  }
  string node_name;
  // If an instruction is fused, put it in the subgraph of the fusion;
  // otherwise, put it in the computation subgraph.
  if (instruction->IsFused()) {
    node_name = GetNodeNameForInstruction(instruction->fusion_instruction());
  } else {
    node_name = instruction->parent()->name();
    if (!instruction->metadata().op_name().empty()) {
      // Always make computations contain TF ops but not the other way around.
      StrAppend(&node_name, "/", instruction->metadata().op_name());
    }
  }
  string instruction_name = instruction->name();
  if (instruction->opcode() == HloOpcode::kParameter) {
    StrAppend(&instruction_name, ".", instruction->parameter_number());


@@ -137,6 +137,28 @@ TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
  EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
}

TEST_F(HloTfGraphBuilderTest, IncorporateTfOpsStructure) {
  auto builder = HloComputation::Builder("GE");
  auto param_1 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto param_2 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  auto ge = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
  OpMetadata metadata;
  metadata.set_op_name("x/y");
  metadata.set_op_type("Y");
  ge->set_metadata(metadata);
  TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
  GraphDef graph_def = generator_.GetGraphDef();
  EXPECT_EQ(graph_def.node_size(), 3);
  EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
  EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
  EXPECT_EQ(graph_def.node(2).input_size(), 2);
  EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to");
  EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
}

TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
  // Create computations with a diamond-shaped callgraph.
  auto negate_computation = CreateNegateComputation();
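
The naming convention the new test pins down is: computation name, then the instruction's TF op_name from OpMetadata, then the instruction name. A tiny hedged mock of that composition follows; ComposeNodeName is hypothetical, the real logic lives in GetNodeNameForInstruction above:

#include <iostream>
#include <string>

// Hypothetical mirror of the node-name layout checked by
// IncorporateTfOpsStructure: "<computation>/<op_name>/<instruction>".
std::string ComposeNodeName(const std::string& computation,
                            const std::string& op_name,
                            const std::string& instruction) {
  std::string node_name = computation;
  if (!op_name.empty()) {
    // Computations contain TF ops, never the other way around.
    node_name += "/" + op_name;
  }
  return node_name + "/" + instruction;
}

int main() {
  std::cout << ComposeNodeName("GE", "x/y", "greater-than-or-equal-to");
  // Prints: GE/x/y/greater-than-or-equal-to
}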


@@ -193,38 +193,6 @@
    tags = ["notap"],  # http://b/30441813
)
cuda_py_test(
name = "bernoulli_test",
size = "small",
srcs = ["python/kernel_tests/bernoulli_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "beta_test",
size = "small",
srcs = ["python/kernel_tests/beta_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
    name = "binomial_test",
    size = "small",
@@ -238,24 +206,6 @@
    ],
)
cuda_py_test(
name = "categorical_test",
size = "small",
srcs = ["python/kernel_tests/categorical_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
)
cuda_py_test(
    name = "chi2_test",
    srcs = ["python/kernel_tests/chi2_test.py"],
@@ -287,66 +237,6 @@
    ],
)
cuda_py_test(
name = "dirichlet_test",
size = "small",
srcs = ["python/kernel_tests/dirichlet_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "dirichlet_multinomial_test",
size = "medium",
srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "exponential_test",
srcs = ["python/kernel_tests/exponential_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "gamma_test",
srcs = ["python/kernel_tests/gamma_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
    name = "geometric_test",
    size = "small",
@@ -378,36 +268,6 @@
    ],
)
cuda_py_test(
name = "laplace_test",
srcs = ["python/kernel_tests/laplace_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "multinomial_test",
srcs = ["python/kernel_tests/multinomial_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
    name = "mvn_diag_test",
    size = "small",
@@ -528,24 +388,6 @@
    tags = ["nomsan"],  # disable to avoid false positives from scipy.
)
cuda_py_test(
name = "student_t_test",
size = "small",
srcs = ["python/kernel_tests/student_t_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
tags = ["nomsan"], # disable to avoid false positives from scipy.
)
cuda_py_test(
    name = "vector_student_t_test",
    size = "medium",
@@ -562,22 +404,6 @@
    ],
)
cuda_py_test(
name = "uniform_test",
size = "small",
srcs = ["python/kernel_tests/uniform_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
],
)
cuda_py_test(
    name = "wishart_test",
    size = "small",


@@ -15,74 +15,6 @@
"""Classes representing statistical distributions and ops for working with them.

See the @{$python/contrib.distributions} guide.
## Distribution Object
@@ReparameterizationType
@@Distribution
## Individual Distributions
@@Binomial
@@Bernoulli
@@BernoulliWithSigmoidProbs
@@Beta
@@BetaWithSoftplusConcentration
@@Categorical
@@Chi2
@@Chi2WithAbsDf
@@Deterministic
@@VectorDeterministic
@@Exponential
@@ExponentialWithSoftplusRate
@@Gamma
@@GammaWithSoftplusConcentrationRate
@@Geometric
@@InverseGamma
@@InverseGammaWithSoftplusConcentrationRate
@@Laplace
@@LaplaceWithSoftplusScale
@@Logistic
@@NegativeBinomial
@@Normal
@@NormalWithSoftplusScale
@@Poisson
@@StudentT
@@StudentTWithAbsDfSoftplusScale
@@Uniform
@@MultivariateNormalDiag
@@MultivariateNormalTriL
@@MultivariateNormalDiagPlusLowRank
@@MultivariateNormalDiagWithSoftplusScale
@@Dirichlet
@@DirichletMultinomial
@@Multinomial
@@WishartCholesky
@@WishartFull
@@TransformedDistribution
@@QuantizedDistribution
@@Mixture
@@ExpRelaxedOneHotCategorical
@@OneHotCategorical
@@RelaxedBernoulli
@@RelaxedOneHotCategorical
## Kullback-Leibler Divergence
@@kl_divergence
@@RegisterKL
## Helper Functions
@@matrix_diag_transform
@@normal_conjugates_known_scale_posterior
@@normal_conjugates_known_scale_predictive
@@softplus_inverse
## Functions for statistics of samples
@@percentile
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -91,25 +23,16 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops.bernoulli import *
from tensorflow.contrib.distributions.python.ops.beta import *
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.categorical import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.dirichlet import *
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
from tensorflow.contrib.distributions.python.ops.exponential import *
from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.geometric import *
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
from tensorflow.contrib.distributions.python.ops.laplace import *
from tensorflow.contrib.distributions.python.ops.logistic import *
from tensorflow.contrib.distributions.python.ops.mixture import *
from tensorflow.contrib.distributions.python.ops.multinomial import *
from tensorflow.contrib.distributions.python.ops.mvn_diag import *
from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import *
from tensorflow.contrib.distributions.python.ops.mvn_tril import *
@@ -121,14 +44,23 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.student_t import *
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.uniform import *
from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.ops.distributions.bernoulli import *
from tensorflow.python.ops.distributions.beta import *
from tensorflow.python.ops.distributions.categorical import *
from tensorflow.python.ops.distributions.conditional_distribution import *
from tensorflow.python.ops.distributions.dirichlet import *
from tensorflow.python.ops.distributions.dirichlet_multinomial import *
from tensorflow.python.ops.distributions.distribution import *
from tensorflow.python.ops.distributions.exponential import *
from tensorflow.python.ops.distributions.gamma import *
from tensorflow.python.ops.distributions.kullback_leibler import *
from tensorflow.python.ops.distributions.laplace import *
from tensorflow.python.ops.distributions.multinomial import *
from tensorflow.python.ops.distributions.normal import *
from tensorflow.python.ops.distributions.student_t import *
from tensorflow.python.ops.distributions.uniform import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
@@ -140,6 +72,71 @@ _allowed_symbols = [
    'ConditionalTransformedDistribution',
    'FULLY_REPARAMETERIZED',
    'NOT_REPARAMETERIZED',
    'Affine',
    'AffineLinearOperator',
    'Bijector',
    'Chain',
    'CholeskyOuterProduct',
    'Exp',
    'Identity',
    'Inline',
    'Invert',
    'PowerTransform',
    'SigmoidCentered',
    'SoftmaxCentered',
    'Softplus',
    'ReparameterizationType',
    'Distribution',
    'Binomial',
    'Bernoulli',
    'BernoulliWithSigmoidProbs',
    'Beta',
    'BetaWithSoftplusConcentration',
    'Categorical',
    'Chi2',
    'Chi2WithAbsDf',
    'Deterministic',
    'VectorDeterministic',
    'Exponential',
    'ExponentialWithSoftplusRate',
    'Gamma',
    'GammaWithSoftplusConcentrationRate',
    'Geometric',
    'InverseGamma',
    'InverseGammaWithSoftplusConcentrationRate',
    'Laplace',
    'LaplaceWithSoftplusScale',
    'Logistic',
    'NegativeBinomial',
    'Normal',
    'NormalWithSoftplusScale',
    'Poisson',
    'StudentT',
    'StudentTWithAbsDfSoftplusScale',
    'Uniform',
    'MultivariateNormalDiag',
    'MultivariateNormalTriL',
    'MultivariateNormalDiagPlusLowRank',
    'MultivariateNormalDiagWithSoftplusScale',
    'Dirichlet',
    'DirichletMultinomial',
    'Multinomial',
    'WishartCholesky',
    'WishartFull',
    'TransformedDistribution',
    'QuantizedDistribution',
    'Mixture',
    'ExpRelaxedOneHotCategorical',
    'OneHotCategorical',
    'RelaxedBernoulli',
    'RelaxedOneHotCategorical',
    'kl_divergence',
    'RegisterKL',
    'matrix_diag_transform',
    'normal_conjugates_known_scale_posterior',
    'normal_conjugates_known_scale_predictive',
    'softplus_inverse',
    'percentile'
]

remove_undocumented(__name__, _allowed_symbols)


@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.platform import test


@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.platform import test


@@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np

from tensorflow.contrib.distributions.python.ops import uniform as uniform_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import uniform as uniform_lib

def assert_finite(array):


@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import gamma
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import gamma

__all__ = [


@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np

from tensorflow.contrib.distributions.python.ops import categorical
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -29,6 +28,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util


@@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import student_t
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import student_t
from tensorflow.python.ops.distributions import util as distribution_util


@@ -921,12 +921,21 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):
  if not labels.dtype.is_integer:
    raise ValueError("Labels dtype should be integer "
                     "Instead got %s." % labels.dtype)

  # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
  is_squeezed_labels = False
  # TODO(ptucker): This will break for dynamic shapes.
  if len(labels.get_shape()) == 2:
    labels = array_ops.squeeze(labels, squeeze_dims=(1,))
    is_squeezed_labels = True

  loss = nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name=name)

  # Restore squeezed dimension, if necessary, so loss matches weights shape.
  if is_squeezed_labels:
    loss = array_ops.expand_dims(loss, axis=(1,))

  return _compute_weighted_loss(loss, weights)


@@ -791,7 +791,7 @@ class BinaryClassificationHeadTest(test.TestCase):
          [b"0", b"1"], predicted_classes[0])
      self.assertIn("probabilities", six.iterkeys(predictions_for_serving))

  def testBinaryClassificationInferMode_withWeightColumn(self):
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes,
                                     weight_column_name="label_weight")
@@ -951,7 +951,7 @@ class MultiClassHeadTest(test.TestCase):
  def setUp(self):
    self._logits = ((1., 0., 0.),)
    self._labels = ((2,),)

  def _expected_eval_metrics(self, expected_loss):
    return {
@@ -1131,7 +1131,7 @@ class MultiClassHeadTest(test.TestCase):
    _assert_metrics(self, expected_loss,
                    expected_eval_metrics, model_fn_ops)

  def testMultiClassWithScalarWeight(self):
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes,
@@ -1154,6 +1154,30 @@ class MultiClassHeadTest(test.TestCase):
      _assert_metrics(self, expected_loss * weight,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWith2DWeight(self):
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      weight = .1
      weights = ((weight,),)
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447
      _assert_metrics(self, expected_loss * weight,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWithCustomLoss(self):
    n_classes = 3
    head = head_lib.multi_class_head(


@ -31,6 +31,17 @@ from tensorflow.python.estimator import run_config as core_run_config
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
_DEFAULT_UID_WHITE_LIST = [
'tf_random_seed',
'save_summary_steps',
'save_checkpoints_steps',
'save_checkpoints_secs',
'session_config',
'keep_checkpoint_max',
'keep_checkpoint_every_n_hours',
]
class Environment(object): class Environment(object):
# For running general distributed training. # For running general distributed training.
CLOUD = 'cloud' CLOUD = 'cloud'
@@ -312,18 +323,29 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):

    return new_copy

  @experimental
  def uid(self, whitelist=None):
    """Generates a 'Unique Identifier' based on all internal fields.

    Caller should use the uid string to check `RunConfig` instance integrity
    in one session use, but should not rely on the implementation details,
    which are subject to change.

    Args:
      whitelist: A list of the string names of the properties uid should not
        include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
        includes most properties a user is allowed to change.

    Returns:
      A uid string.
    """
    if whitelist is None:
      whitelist = _DEFAULT_UID_WHITE_LIST

    state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
    # Pop out the keys in whitelist.
    for k in whitelist:
      state.pop('_' + k, None)

    ordered_state = collections.OrderedDict(
        sorted(state.items(), key=lambda t: t[0]))
    # For class instances without __repr__, some special care is required.
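To make the whitelist mechanics concrete, here is a minimal self-contained sketch of the same idea in plain Python, independent of TensorFlow: the uid is derived from a sorted snapshot of the instance's fields, with whitelisted property names popped out first so that changing them does not perturb the identifier. The `FakeConfig` class and `_uid` helper below are hypothetical, not part of the patch:

import collections

class FakeConfig(object):
  def __init__(self, **kwargs):
    for k, v in kwargs.items():
      setattr(self, '_' + k, v)

def _uid(config, whitelist=()):
  state = {k: v for k, v in config.__dict__.items() if not k.startswith('__')}
  for k in whitelist:
    state.pop('_' + k, None)  # whitelisted fields do not affect the uid
  ordered = collections.OrderedDict(sorted(state.items()))
  return repr(ordered)

a = _uid(FakeConfig(model_dir='/tmp/a', seed=1), whitelist=['model_dir'])
b = _uid(FakeConfig(model_dir='/tmp/b', seed=1), whitelist=['model_dir'])
assert a == b  # model_dir is ignored, so the uids match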


@@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase):

    self.assertNotEqual(expected_uid, new_config.uid())
    self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
def test_uid_for_whitelist(self):
whitelist = ["model_dir"]
config = run_config_lib.RunConfig(
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
expected_uid = config.uid(whitelist)
self.assertEqual(expected_uid, config.uid(whitelist))
new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertEqual(TEST_DIR, config.model_dir)
self.assertEqual(expected_uid, new_config.uid(whitelist))
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
def test_uid_for_default_whitelist(self):
config = run_config_lib.RunConfig(
tf_random_seed=11,
save_summary_steps=12,
save_checkpoints_steps=13,
save_checkpoints_secs=14,
session_config=15,
keep_checkpoint_max=16,
keep_checkpoint_every_n_hours=17)
self.assertEqual(11, config.tf_random_seed)
self.assertEqual(12, config.save_summary_steps)
self.assertEqual(13, config.save_checkpoints_steps)
self.assertEqual(14, config.save_checkpoints_secs)
self.assertEqual(15, config.session_config)
self.assertEqual(16, config.keep_checkpoint_max)
self.assertEqual(17, config.keep_checkpoint_every_n_hours)
new_config = run_config_lib.RunConfig(
tf_random_seed=21,
save_summary_steps=22,
save_checkpoints_steps=23,
save_checkpoints_secs=24,
session_config=25,
keep_checkpoint_max=26,
keep_checkpoint_every_n_hours=27)
self.assertEqual(config.uid(), new_config.uid())
# model_dir is not on the default whitelist.
self.assertNotEqual(config.uid(whitelist=[]),
new_config.uid(whitelist=[]))
new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertNotEqual(config.uid(), new_config.uid())
  def test_uid_for_deepcopy(self):
    tf_config = {
        "cluster": {


@@ -293,8 +293,7 @@ class LearnRunnerRunWithRunConfigTest(test.TestCase):

    def _experiment_fn(run_config, hparams):
      del run_config, hparams  # unused.
      # Explicitly use a new run_config.
      new_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123")
      return TestExperiment(config=new_config)


@@ -22,10 +22,26 @@ from __future__ import absolute_import

from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.losses.python.losses import *
# pylint: enable=wildcard-import

from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'absolute_difference',
'add_loss',
'hinge_loss',
'compute_weighted_loss',
'cosine_distance',
'get_losses',
'get_regularization_losses',
'get_total_loss',
'log_loss',
'mean_pairwise_squared_error',
'mean_squared_error',
'sigmoid_cross_entropy',
'softmax_cross_entropy',
'sparse_softmax_cross_entropy',
]
remove_undocumented(__name__, _allowed_symbols)
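The switch to an explicit _allowed_symbols list pins this module's public surface to a fixed set of names. The remove_undocumented call above is TensorFlow's own helper; the function below is a hypothetical stand-in sketching the same pruning pattern in plain Python:

def _prune_module(module_globals, allowed):
  # Delete every public name that is not explicitly allowed.
  for name in list(module_globals):
    if not name.startswith('_') and name not in allowed:
      del module_globals[name]

# Example: after pruning, only 'log_loss' survives as a public name.
fake_module = {'log_loss': object(), 'helper_fn': object(), '_private': 1}
_prune_module(fake_module, allowed=['log_loss'])
assert sorted(fake_module) == ['_private', 'log_loss']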


@@ -12,127 +12,15 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for building neural network losses.

See @{$python/contrib.losses}.

Note: By default all the losses are collected into the `GraphKeys.LOSSES`
collection.
All of the loss functions take a pair of predictions and ground truth labels,
from which the loss is computed. It is assumed that the shape of both these
tensors is of the form [batch_size, d1, ... dN] where `batch_size` is the number
of samples in the batch and `d1` ... `dN` are the remaining dimensions.
It is common, when training with multiple loss functions, to adjust the relative
strengths of individual losses. This is performed by rescaling the losses via
a `weight` parameter passed to the loss functions. For example, if we were
training with both log_loss and sum_of_squares_loss, and we wished that the
log_loss penalty be twice as severe as the sum_of_squares_loss, we would
implement this as:
  # Explicitly set the weight.
tf.contrib.losses.log(predictions, labels, weight=2.0)
# Uses default weight of 1.0
tf.contrib.losses.sum_of_squares(predictions, labels)
# All the losses are collected into the `GraphKeys.LOSSES` collection.
losses = tf.get_collection(tf.GraphKeys.LOSSES)
While specifying a scalar loss rescales the loss over the entire batch,
we sometimes want to rescale the loss per batch sample. For example, if we have
certain examples that matter more to us to get correctly, we might want to have
a higher loss than other samples whose mistakes matter less. In this case, we
can provide a weight vector of length `batch_size` which results in the loss
for each sample in the batch being scaled by the corresponding weight element.
For example, consider the case of a classification problem where we want to
maximize our accuracy, but we are especially interested in obtaining high
accuracy for a specific class:
inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)
# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.multiply(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1
onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
Finally, in certain cases, we may want to specify a different loss for every
single measurable value. For example, if we are performing per-pixel depth
prediction, or per-pixel denoising, a single batch sample has P values where P
is the number of pixels in the image. For many losses, the number of measurable
values matches the number of elements in the predictions and labels tensors.
For others, such as softmax_cross_entropy and cosine_distance, the
loss function reduces the dimensions of the inputs to produce a tensor of
losses for each measurable value. For example, softmax_cross_entropy takes as
input predictions and labels of dimension [batch_size, num_classes] but the
number of measurable values is [batch_size]. Consequently, when passing a weight
tensor to specify a different loss for every measurable value, the dimension of
the tensor will depend on the loss being used.
For a concrete example, consider the case of per-pixel depth prediction where
certain ground truth depth values are missing (due to sensor noise in the
capture process). In this case, we want to assign zero weight to losses for
these predictions.
# 'depths' that are missing have a value of 0:
images, depths = LoadData(...)
predictions = MyModelPredictions(images)
weight = tf.cast(tf.greater(depths, 0), tf.float32)
loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight)
Note that when using weights for the losses, the final average is computed
by rescaling the losses by the weights and then dividing by the total number of
non-zero samples. For an arbitrary set of weights, this may not necessarily
produce a weighted average. Instead, it simply and transparently rescales the
per-element losses before averaging over the number of observations. For
example, if the losses computed by the loss function are an array [4, 1, 2, 3]
and the
weights are an array [1, 0.5, 3, 9], then the average loss is:
(4*1 + 1*0.5 + 2*3 + 3*9) / 4
However, with a single loss function and an arbitrary set of weights, one can
still easily create a loss function such that the resulting loss is a
weighted average over the individual prediction errors:
images, labels = LoadData(...)
predictions = MyModelPredictions(images)
weight = MyComplicatedWeightingFunction(labels)
weight = tf.div(weight, tf.size(weight))
  loss = tf.contrib.losses.sum_of_squares(predictions, labels, weight)
@@absolute_difference
@@add_loss
@@hinge_loss
@@compute_weighted_loss
@@cosine_distance
@@get_losses
@@get_regularization_losses
@@get_total_loss
@@log_loss
@@mean_pairwise_squared_error
@@mean_squared_error
@@sigmoid_cross_entropy
@@softmax_cross_entropy
@@sparse_softmax_cross_entropy
The following are deprecated in favor of `mean_pairwise_squared_error` and
`mean_squared_error`.
@@sum_of_pairwise_squares
@@sum_of_squares
""" """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.losses.python.losses.loss_ops import *
# pylint: enable=wildcard-import


@@ -47,7 +47,6 @@ GRAPH_TRANSFER_SRCS := \

tensorflow/cc/framework/scope.cc \
tensorflow/cc/framework/ops.cc \
tensorflow/cc/ops/const_op.cc \
tensorflow/core/kernels/function_ops.cc \
tensorflow/core/kernels/hexagon/graph_transfer_utils.cc \
tensorflow/core/kernels/hexagon/graph_transferer.cc \
tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc \


@@ -16,36 +16,6 @@

"""Ops for building neural network seq2seq decoders and losses.

See the @{$python/contrib.seq2seq} guide.
""" """
from __future__ import absolute_import from __future__ import absolute_import
@@ -63,6 +33,30 @@ from tensorflow.contrib.seq2seq.python.ops.loss import *

from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,wildcard-import,line-too-long

_allowed_symbols = [
"sequence_loss",
"Decoder",
"dynamic_decode",
"BasicDecoder",
"BasicDecoderOutput",
"BeamSearchDecoder",
"BeamSearchDecoderOutput",
"BeamSearchDecoderState",
"Helper",
"CustomHelper",
"FinalBeamSearchDecoderOutput",
"gather_tree",
"GreedyEmbeddingHelper",
"ScheduledEmbeddingTrainingHelper",
"ScheduledOutputTrainingHelper",
"TrainingHelper",
"BahdanauAttention",
"LuongAttention",
"hardmax",
"AttentionWrapperState",
"AttentionWrapper",
"AttentionMechanism",
"tile_batch"]
remove_undocumented(__name__, _allowed_symbols)


@@ -39,6 +39,7 @@ from tensorflow.python.util import nest

__all__ = [
"AttentionMechanism",
"AttentionWrapper", "AttentionWrapper",
"AttentionWrapperState", "AttentionWrapperState",
"LuongAttention", "LuongAttention",


@@ -23,8 +23,6 @@ import abc

import six
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -35,6 +33,8 @@ from tensorflow.python.ops import embedding_ops

from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.util import nest

__all__ = [


@@ -88,10 +88,7 @@ Status ShapeRefiner::AddNode(const Node* node) {

  }

  // This needs to be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
std::vector<Tensor> real_tensors(node->num_inputs());
std::vector<bool> attempted_materialization(node->num_inputs());
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
std::vector<ShapeHandle> input_tensors_as_shapes; std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes. // Create the inference context for this node with the existing input shapes.
@ -104,78 +101,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
} }
// Run the shape inference function, and return if there was an error. // Run the shape inference function, and return if there was an error.
if (op_reg_data->shape_inference_fn) { TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, c.get()));
TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
} else {
TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
}
// We must run the shape function repeatedly, in case users write
// shape functions where they only conditionally call input_tensor()
// based on the values of another input tensor.
bool rerun_shape_fn;
do {
// If the result of running shape inference would have benefitted
// from knowing the values of input tensors, try to materialize
// the results of those tensors, and then run the shape inference
// function again using those known tensors.
rerun_shape_fn = false;
// NOTE: It is possible to batch the extraction and
// materialization of inputs, instead of materializing one input
// at a time like we do below. If input-at-a-time computation
// becomes a bottleneck, we could separate ExtractConstantSubgraph
// into two functions: one that returns true if an input is
// derivable from constants, and another function that extracts
// the subgraph for multiple target nodes and executes the whole
// subgraph once.
for (int i = 0; i < c->num_inputs(); ++i) {
if (!c->requested_input_tensor(i)) {
continue;
}
// Check if we have not already filled in the requested input,
// and if not, try to materialize the tensors.
if (!attempted_materialization[i]) {
attempted_materialization[i] = true;
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
if (evaluated) {
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
}
}
if (c->requested_input_tensor_as_partial_shape(i) &&
!attempted_tensor_as_shape_conversion[i]) {
attempted_tensor_as_shape_conversion[i] = true;
if (i >= input_tensors_as_shapes.size()) {
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
}
if (rerun_shape_fn) {
// We have more information about the shapes on this pass,
// so re-run shape inference.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get()));
} else {
TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c.get()));
}
}
} while (rerun_shape_fn);
  // Store the resulting InferenceContext object in the map.
  node_to_context_[node].swap(c);

@@ -211,6 +137,74 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,

  return Status::OK();
}
Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
auto it = node_to_context_.find(node);
if (it == node_to_context_.end()) {
*refined = true;
return AddNode(node);
}
InferenceContext* node_context = it->second.get();
// Give up if the context wasn't successfully built by the AddNode() method.
TF_RETURN_IF_ERROR(node_context->construction_status());
// Check if the shapes of the nodes in the fan-in of this node have changed,
// and if they have update the node input shapes.
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
Node* input = e->src();
auto iter = node_to_context_.find(input);
if (iter == node_to_context_.end()) {
return errors::FailedPrecondition(
"Input ", e->dst_input(), " ('", input->name(), "') for '",
node->name(), "' was not previously added to ShapeRefiner.");
}
InferenceContext* c = iter->second.get();
DCHECK_GE(e->dst_input(), 0);
if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) {
*refined = true;
}
// Also propagate handle shape and dtype of edges which are carrying
// resource handles.
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
if (node_context->set_input_handle_dtype(
e->dst_input(), c->output_handle_dtype(e->src_output()))) {
*refined = true;
}
if (node_context->set_input_handle_shape(
e->dst_input(), c->output_handle_shape(e->src_output()))) {
*refined = true;
}
}
}
if (!*refined) {
// No input shape has changed, we're done
return Status::OK();
}
// Get and run the shape function for this node to update the shapes of the
// outputs.
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
if (op_reg_data->shape_inference_fn == nullptr &&
require_shape_inference_fns_) {
return errors::InvalidArgument(
"No shape inference function exists for op '", node->type_string(),
"', did you forget to define it?");
}
if (!op_reg_data->shape_inference_fn) {
// There is nothing more we can infer
return Status::OK();
}
return RunShapeFn(node, op_reg_data, node_context);
}
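UpdateNode only re-runs a node's shape function when at least one of its input shapes (or the handle shapes/dtypes carried by resource edges) actually changed, which is what makes incremental refinement cheap. The following is a minimal sketch of that fixed-point pattern in Python, assuming a toy graph where each node's "shape function" just forwards its input; none of these names come from the TensorFlow code base:

import collections

def propagate(edges, shapes, seeds):
  # edges: {node: [fanout, ...]}; shapes: {node: shape}; seeds: changed nodes.
  worklist = collections.deque(seeds)
  while worklist:
    node = worklist.popleft()
    for fanout in edges.get(node, []):
      new_shape = shapes[node]          # toy "shape function": identity
      if shapes.get(fanout) != new_shape:
        shapes[fanout] = new_shape      # refined: keep propagating
        worklist.append(fanout)

edges = {'queue': ['dequeue'], 'dequeue': ['square']}
shapes = {'queue': (3, 7), 'dequeue': None, 'square': None}
propagate(edges, shapes, seeds=['queue'])
assert shapes['square'] == (3, 7)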
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
@@ -463,4 +457,93 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,

  return Status::OK();
}
Status ShapeRefiner::RunShapeFn(const Node* node,
const OpRegistrationData* op_reg_data,
shape_inference::InferenceContext* c) {
// This will be filled in with real data in a second pass.
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
std::vector<Tensor> real_tensors(node->num_inputs());
std::vector<bool> attempted_materialization(node->num_inputs());
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
std::vector<ShapeHandle> input_tensors_as_shapes;
// Run the shape inference function, and return if there was an error.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
} else {
TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
}
// We must run the shape function repeatedly, in case users write
// shape functions where they only conditionally call input_tensor()
// based on the values of another input tensor.
bool rerun_shape_fn;
do {
// If the result of running shape inference would have benefitted
// from knowing the values of input tensors, try to materialize
// the results of those tensors, and then run the shape inference
// function again using those known tensors.
rerun_shape_fn = false;
// NOTE: It is possible to batch the extraction and
// materialization of inputs, instead of materializing one input
// at a time like we do below. If input-at-a-time computation
// becomes a bottleneck, we could separate ExtractConstantSubgraph
// into two functions: one that returns true if an input is
// derivable from constants, and another function that extracts
// the subgraph for multiple target nodes and executes the whole
// subgraph once.
for (int i = 0; i < c->num_inputs(); ++i) {
if (!c->requested_input_tensor(i)) {
continue;
}
// Check if we have not already filled in the requested input,
// and if not, try to materialize the tensors.
if (!attempted_materialization[i]) {
attempted_materialization[i] = true;
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
if (evaluated) {
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
}
}
if (c->requested_input_tensor_as_partial_shape(i) &&
!attempted_tensor_as_shape_conversion[i]) {
attempted_tensor_as_shape_conversion[i] = true;
if (i >= input_tensors_as_shapes.size()) {
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
}
if (rerun_shape_fn) {
// We have more information about the shapes on this pass,
// so re-run shape inference.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c));
} else {
TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c));
}
}
} while (rerun_shape_fn);
return Status::OK();
}
}  // namespace tensorflow


@@ -55,6 +55,11 @@ class ShapeRefiner {

  Status SetShape(const Node* node, int output_port,
                  shape_inference::ShapeHandle shape);

  // Update the input shapes of 'node' in case the shapes of its fan-ins have
  // themselves been modified (for example, during incremental shape
  // refinement). Sets 'refined' to true if any of the node's shapes changed.
  Status UpdateNode(const Node* node, bool* refined);
  // Returns the InferenceContext for 'node', if present.
  shape_inference::InferenceContext* GetContext(const Node* node) const {
    auto it = node_to_context_.find(node);
@@ -108,6 +113,9 @@ class ShapeRefiner {

      const Node* node, int dst_idx,
      shape_inference::ShapeHandle* result);
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
shape_inference::InferenceContext* c);
  int32 graph_def_version_;
  const OpRegistryInterface* const ops_registry_;


@@ -768,5 +768,38 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {

            m.AddNode(result).error_message());
}
TEST(ShapeRefinerTest, IncrementalUpdates) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* queue;
TF_CHECK_OK(NodeBuilder("queue", "FIFOQueueV2")
.Attr("component_types", {DT_FLOAT})
.Finalize(g, &queue));
Node* dequeue;
TF_CHECK_OK(NodeBuilder("dequeue", "QueueDequeueV2")
.Attr("component_types", {DT_FLOAT})
.Input(queue)
.Finalize(g, &dequeue));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(queue));
TF_ASSERT_OK(m.AddNode(dequeue));
// At this point, the shapes of the dequeued tensor are unknown.
shape_inference::InferenceContext* ctx = m.GetContext(dequeue);
EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
// Inject a shape, and incrementally propagate it to the dequeue op.
ctx = m.GetContext(queue);
shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7});
ctx->set_output_handle_shape(0, shp);
ctx->set_output_handle_dtype(0, DT_FLOAT);
bool refined = false;
TF_ASSERT_OK(m.UpdateNode(dequeue, &refined));
EXPECT_TRUE(refined);
ctx = m.GetContext(dequeue);
EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0)));
}
}  // namespace
}  // namespace tensorflow


@@ -582,7 +582,7 @@ string Print(const GraphDef& gdef) {

  for (size_t i = 0; i < arg.size(); ++i) {
    const NodeDef* n = arg[i];
    if (i > 0) strings::StrAppend(&out, ", ");
    CHECK_GE(n->attr_size(), 2);
    strings::StrAppend(&out, n->name(), ":", get_type(*n));
  }
  strings::StrAppend(&out, ") -> (");


@@ -191,6 +191,17 @@ class InferenceContext {

    return s;
  }
// Set the shape of the input in position idx. This requires idx to be in the
// [0, num_inputs) range. Returns true iff the stored input shape has been
// updated with a different handle.
bool set_input(int idx, ShapeHandle shape) {
if (!inputs_[idx].SameHandle(shape)) {
inputs_[idx] = shape;
return true;
} else {
return false;
}
}
  ShapeHandle input(int64 idx) const { return inputs_[idx]; }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  int num_inputs() const { return inputs_.size(); }

@@ -430,15 +441,53 @@ class InferenceContext {

  // and dtypes of tensors which can be accessed via the handle. These methods
  // propagate that information. Output handle dtypes and shapes are ignored if
  // the output tensor is not of type DT_RESOURCE.
// Set the shape corresponding to the resource in position idx. This requires
// idx to be in the [0, num_inputs) range. Returns true iff the stored shape
// has been updated with a different handle.
bool set_input_handle_shape(int idx, ShapeHandle shape) {
if (!input_handle_shape_[idx].SameHandle(shape)) {
input_handle_shape_[idx] = shape;
return true;
}
return false;
}
// Set the type corresponding to the resource in position idx. This requires
// idx to be in the [0, num_inputs) range. Returns true iff the stored type
// has been updated.
bool set_input_handle_dtype(int idx, DataType dtype) {
if (input_handle_dtype_[idx] != dtype) {
input_handle_dtype_[idx] = dtype;
return true;
}
return false;
}
  ShapeHandle input_handle_shape(int idx);
  DataType input_handle_dtype(int idx) const {
    return input_handle_dtype_[idx];
  }
  // Set the shape corresponding to the resource in position idx. This requires
  // idx to be in the [0, num_outputs) range.
  // Returns true iff the stored shape has been updated with a different handle.
  bool set_output_handle_shape(int idx, ShapeHandle shape) {
    if (!output_handle_shape_[idx].SameHandle(shape)) {
      output_handle_shape_[idx] = shape;
      return true;
    }
    return false;
  }

  // Set the type corresponding to the resource in position idx. This requires
  // idx to be in the [0, num_outputs) range. Returns true iff the stored type
  // has been updated.
  bool set_output_handle_dtype(int idx, DataType dtype) {
    if (output_handle_dtype_[idx] != dtype) {
      output_handle_dtype_[idx] = dtype;
      return true;
    }
    return false;
  }
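All four setters now follow the same convention: write only when the stored value would actually change, and return true so the caller can decide whether further propagation is needed. A hypothetical Python rendering of that convention, for illustration only:

class HandleInfo(object):
  """Tracks a shape/dtype pair; setters report whether anything changed."""

  def __init__(self):
    self.shape = None
    self.dtype = None

  def set_shape(self, shape):
    if self.shape != shape:
      self.shape = shape
      return True   # caller should keep propagating
    return False    # no-op: stops the fixed-point iteration

info = HandleInfo()
assert info.set_shape((3, 7)) is True
assert info.set_shape((3, 7)) is False  # unchanged, nothing to propagate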
  ShapeHandle output_handle_shape(int idx) const {
    return output_handle_shape_[idx];


@@ -558,6 +558,11 @@ TEST_F(ShapeInferenceTest, MergeShape) {

  EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1)));
auto s_u1 = c.UnknownShapeOfRank(1);
auto s_u2 = c.UnknownShapeOfRank(1);
TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out));
EXPECT_TRUE(SameHandle(s_u1, out));
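This new case pins down a useful Merge property: merging two shapes of known rank but fully unknown dimensions should hand back the first operand's handle rather than allocate a fresh shape. A small Python sketch of that merge rule, under the simplifying assumption that unknown dimensions are modeled as None:

def merge_shapes(a, b):
  # Both ranks known: they must agree.
  if len(a) != len(b):
    raise ValueError("incompatible ranks: %d vs %d" % (len(a), len(b)))
  merged = []
  for da, db in zip(a, b):
    if da is not None and db is not None and da != db:
      raise ValueError("incompatible dims: %r vs %r" % (da, db))
    merged.append(da if da is not None else db)
  # If 'a' already carries all the merged information, return it unchanged,
  # mirroring the SameHandle(s_u1, out) expectation above.
  return a if merged == list(a) else tuple(merged)

s1, s2 = (None,), (None,)
assert merge_shapes(s1, s2) is s1            # the same handle comes back
assert merge_shapes((3, None), (None, 7)) == (3, 7)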
  // Incompatible merges give errors and set out to nullptr.
  out = s_unknown;
  EXPECT_TRUE(


@@ -58,6 +58,7 @@ cc_library(

        "//tensorflow/core:core_cpu",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:ops_util",
], ],


@@ -18,6 +18,7 @@ limitations under the License.

#include <memory>
#include "tensorflow/cc/training/queue_runner.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
@ -111,6 +112,8 @@ Status SingleMachine::Run(const GraphDef& graph_def,
for (auto node : *init_metadata_.mutable_cost_graph()->mutable_node()) { for (auto node : *init_metadata_.mutable_cost_graph()->mutable_node()) {
node.clear_compute_cost(); node.clear_compute_cost();
} }
// Also clear the timeline to save memory
init_metadata_.clear_step_stats();
  }

  for (int i = 0; i < queue_runner_defs_.size(); ++i) {
    std::unique_ptr<QueueRunner> queue_runner;

@@ -133,15 +136,17 @@ Status SingleMachine::Run(const GraphDef& graph_def,

    }
  }
  if (metadata) {
    TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, metadata));
    // Merge the costs of the initialization and the queue runners.
    CostGraphDef queue_costs;
    TF_RETURN_IF_ERROR(coordinator_->ExportCostGraph(&queue_costs));
    MergeCosts(metadata->mutable_cost_graph(), init_metadata_.cost_graph(),
               queue_costs);
  } else {
    return RunWithTimeout(feed, fetch, nullptr);
  }

  return Status::OK();
}
Status SingleMachine::RunWithTimeout(

@@ -249,5 +254,36 @@ Status SingleMachine::ResetSession() {

  return Status::OK();
}
void SingleMachine::MergeCosts(CostGraphDef* graph_costs,
const CostGraphDef& init_costs,
const CostGraphDef& queue_costs) {
graph_costs->mutable_node()->Reserve(graph_costs->node_size() +
init_costs.node_size() +
queue_costs.node_size());
std::unordered_set<string> nodes_seen;
for (const auto& node : graph_costs->node()) {
nodes_seen.insert(node.name());
}
// The costs obtained by running the main graph could be more stable than
// the one we get from the queue runners since the queue runners run
// asynchronously.
for (const auto& node : queue_costs.node()) {
if (nodes_seen.find(node.name()) != nodes_seen.end()) {
continue;
}
graph_costs->add_node()->MergeFrom(node);
}
  // Don't overwrite the costs with those generated during initialization,
  // since these are possibly outdated.
for (const auto& node : init_costs.node()) {
if (nodes_seen.find(node.name()) != nodes_seen.end()) {
continue;
}
graph_costs->add_node()->MergeFrom(node);
}
}
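MergeCosts gives the main-graph measurements priority, then queue-runner costs, then initialization costs, using a name-seen set so later sources never overwrite earlier ones. A compact sketch of this precedence merge in Python, with dicts standing in for CostGraphDef nodes and all names illustrative:

def merge_costs(graph_costs, init_costs, queue_costs):
  seen = {node['name'] for node in graph_costs}
  # Queue-runner costs come second, init costs last: earlier sources win.
  for source in (queue_costs, init_costs):
    for node in source:
      if node['name'] not in seen:
        graph_costs.append(node)
        seen.add(node['name'])
  return graph_costs

main = [{'name': 'a', 'cost': 3}]
queue = [{'name': 'a', 'cost': 9}, {'name': 'q', 'cost': 2}]
init = [{'name': 'q', 'cost': 7}, {'name': 'i', 'cost': 1}]
merged = merge_costs(main, init, queue)
assert [n['name'] for n in merged] == ['a', 'q', 'i']
assert merged[0]['cost'] == 3  # main-graph measurement kept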
}  // namespace grappler
}  // namespace tensorflow


@@ -47,6 +47,8 @@ class SingleMachine : public Cluster {

                        RunMetadata* run_metadata, int64 timeout_s);
  Status ResetSession();
  Status CloseSession(bool use_timeout);
void MergeCosts(CostGraphDef* graph_costs, const CostGraphDef& init_costs,
const CostGraphDef& queue_costs);
  const int num_gpus_;
  std::unique_ptr<Session> session_;


@@ -159,6 +159,121 @@ TEST_F(SingleMachineTest, InitializationMemory) {

  EXPECT_TRUE(found);
}
namespace {
template <class T>
inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
auto* attr_map = node->mutable_attr();
(*attr_map)[key] = attr_value;
}
template <>
inline void SetNodeAttr(const string& key, const Tensor& tensor,
NodeDef* node) {
TensorProto tensor_proto;
tensor.AsProtoTensorContent(&tensor_proto);
SetNodeAttr(key, tensor_proto, node);
}
} // namespace
TEST_F(SingleMachineTest, PersistentMemory) {
// Build a hashtable and its initialization graph.
GrapplerItem item;
const DataType key_dtype = DT_INT64;
const DataType data_dtype = DT_INT64;
NodeDef* hashtable_node = item.graph.add_node();
hashtable_node->set_op("HashTable");
hashtable_node->set_name("hash_table");
SetNodeAttr("key_dtype", key_dtype, hashtable_node);
SetNodeAttr("value_dtype", data_dtype, hashtable_node);
// Initial hashtable keys and values
NodeDef* keys_node = item.graph.add_node();
keys_node->set_op("Const");
keys_node->set_name("table_keys");
SetNodeAttr("dtype", key_dtype, keys_node);
Tensor keys(key_dtype, TensorShape{2});
keys.vec<int64>()(0) = 123;
keys.vec<int64>()(1) = 321;
SetNodeAttr("value", keys, keys_node);
NodeDef* values_node = item.graph.add_node();
values_node->set_op("Const");
values_node->set_name("table_values");
SetNodeAttr("dtype", data_dtype, values_node);
Tensor values(data_dtype, TensorShape{2});
values.vec<int64>()(0) = 789;
values.vec<int64>()(1) = 987;
SetNodeAttr("value", values, values_node);
// InitializeTable node
NodeDef* init_table_node = item.graph.add_node();
init_table_node->set_op("InitializeTable");
init_table_node->set_name("initialize_table");
SetNodeAttr("Tkey", key_dtype, init_table_node);
SetNodeAttr("Tval", data_dtype, init_table_node);
*init_table_node->add_input() = "hash_table";
*init_table_node->add_input() = "table_keys";
*init_table_node->add_input() = "table_values";
item.init_ops.push_back(init_table_node->name());
// Key to lookup
NodeDef* query_node = item.graph.add_node();
query_node->set_op("Const");
query_node->set_name("query");
SetNodeAttr("dtype", key_dtype, query_node);
Tensor query(key_dtype, TensorShape({}));
query.flat<int64>()(0) = 0;
SetNodeAttr("value", query, query_node);
// Default return value of hashtable lookup
NodeDef* default_value_node = item.graph.add_node();
default_value_node->set_op("Const");
default_value_node->set_name("default_table_value");
SetNodeAttr("dtype", data_dtype, default_value_node);
Tensor dflt(data_dtype, TensorShape({}));
dflt.flat<int64>()(0) = 456;
SetNodeAttr("value", dflt, default_value_node);
// HashTable lookup node
NodeDef* lookup_node = item.graph.add_node();
lookup_node->set_op("LookupTableFind");
lookup_node->set_name("table_lookup");
SetNodeAttr("Tin", key_dtype, lookup_node);
SetNodeAttr("Tout", data_dtype, lookup_node);
*lookup_node->add_input() = "hash_table";
*lookup_node->add_input() = "query";
*lookup_node->add_input() = "default_table_value";
item.fetch.push_back(lookup_node->name());
// Run the graph
TF_CHECK_OK(cluster_->Initialize(item));
RunMetadata metadata;
TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
// Check the cost model.
bool found_table_init = false;
bool found_hashtable = false;
for (const auto& node : metadata.cost_graph().node()) {
if (node.name() == "hash_table") {
found_hashtable = true;
// Persistent memory usage should be 0 since it's recorded as part of the
// initialize_table op.
EXPECT_EQ(0, node.host_persistent_memory_size());
EXPECT_EQ(0, node.device_persistent_memory_size());
} else if (node.name() == "initialize_table") {
found_table_init = true;
// Persistent memory should hold 2 keys and 2 values.
EXPECT_LE(4 * sizeof(int64), node.host_persistent_memory_size());
EXPECT_EQ(0, node.device_persistent_memory_size());
}
}
EXPECT_TRUE(found_table_init);
EXPECT_TRUE(found_hashtable);
}
}  // namespace
}  // namespace grappler
}  // namespace tensorflow


@@ -50,11 +50,13 @@ cc_test(

    args = ["--heap_check=local"],  # The GPU tracer leaks memory
    deps = [
        ":graph_properties",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:single_machine", "//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
], ],


@@ -15,6 +15,9 @@ limitations under the License.

#include "tensorflow/core/grappler/costs/graph_properties.h"
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
@ -31,6 +34,76 @@ Status GraphProperties::InferStatically() {
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s); TF_RETURN_IF_ERROR(s);
// List the resources and the nodes using them
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) {
if (node->input_type(i) == DataType::DT_RESOURCE) {
const Node* resource;
TF_CHECK_OK(node->input_node(i, &resource));
resources[resource].insert(node);
}
}
}
// If we found a resource, try to propagate the shapes through it.
bool done = true;
do {
std::queue<const Node*> new_shapes;
for (const auto& resource_data : resources) {
const Node* qnode = resource_data.first;
StringPiece type(qnode->type_string());
if (!type.ends_with("QueueV2")) {
continue;
}
auto qctx = shape_refiner.GetContext(qnode);
if (!qctx) {
continue;
}
DataType queue_type = qctx->output_handle_dtype(0);
shape_inference::ShapeHandle queue_shp = qctx->output_handle_shape(0);
if (qctx->FullyDefined(queue_shp) && queue_type != DT_INVALID) {
continue;
}
for (const auto& node : resource_data.second) {
auto ctx = shape_refiner.GetContext(node);
if (!ctx) {
continue;
}
if (node->type_string().find("Enqueue") != std::string::npos) {
if (ctx->num_inputs() == 2) {
const DataType dtype = node->input_type(1);
if (queue_type == DT_INVALID) {
queue_type = dtype;
} else {
CHECK_EQ(queue_type, dtype);
}
shape_inference::ShapeHandle shp = ctx->input(1);
TF_RETURN_IF_ERROR(qctx->Merge(queue_shp, shp, &queue_shp));
}
}
}
if (qctx->set_output_handle_dtype(0, queue_type) ||
qctx->set_output_handle_shape(0, queue_shp)) {
new_shapes.push(qnode);
}
}
// Propagate the shapes in the transitive fan-out of the queue.
done = new_shapes.empty();
while (!new_shapes.empty()) {
const Node* n = new_shapes.front();
new_shapes.pop();
for (const Node* fanout : n->out_nodes()) {
bool updated = false;
TF_RETURN_IF_ERROR(shape_refiner.UpdateNode(fanout, &updated));
if (updated) {
new_shapes.push(fanout);
}
}
}
} while (!done);
  for (const Node* const node : graph.nodes()) {
    VLOG(1) << "<Node> " << node->name();
    auto ctx = shape_refiner.GetContext(node);
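The queue loop above merges the input shapes of every enqueue op feeding a queue before stamping the result on the queue's output handle, so two enqueues with incompatible shapes are caught and partially known shapes are combined. A simplified Python sketch of that merge step, reusing the None-for-unknown-dimension convention; all names are illustrative:

def merge_dim(da, db):
  if da is None: return db
  if db is None: return da
  if da != db:
    raise ValueError("enqueue ops disagree: %r vs %r" % (da, db))
  return da

def queue_output_shape(enqueue_shapes):
  # Merge the (partially known) shapes of all enqueue ops feeding the queue;
  # incompatible enqueues raise, mirroring the CHECK/Merge in the loop above.
  merged = enqueue_shapes[0]
  for shape in enqueue_shapes[1:]:
    if len(shape) != len(merged):
      raise ValueError("enqueue ops disagree on rank")
    merged = tuple(merge_dim(da, db) for da, db in zip(merged, shape))
  return merged

# One enqueue with a fully known shape, one with only the rank known:
assert queue_output_shape([(3, 7), (None, None)]) == (3, 7)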


@@ -14,6 +14,9 @@ limitations under the License.

==============================================================================*/

#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@ -129,6 +132,101 @@ TEST_F(GraphPropertiesTest, DynamicProperties) {
} }
} }
TEST_F(GraphPropertiesTest, VarHandles) {
GrapplerItem item;
TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
.Attr("dtype", DT_FLOAT)
.Attr("shape", TensorShape({3, 7}))
.Finalize(item.graph.add_node()));
TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
.Attr("dtype", DT_FLOAT)
.Input("Var", 0, DT_RESOURCE)
.Finalize(item.graph.add_node()));
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically());
const auto props = properties.GetOutputProperties("VarRead");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
EXPECT_FALSE(prop.shape().unknown_rank());
EXPECT_EQ(2, prop.shape().dim_size());
EXPECT_EQ(3, prop.shape().dim(0).size());
EXPECT_EQ(7, prop.shape().dim(1).size());
}
TEST_F(GraphPropertiesTest, Queues) {
// Create a graph with known input shapes, and propagate the shapes through a
// couple of queues.
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
Output rnd =
ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
auto dequeue1 =
ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
auto q2 =
ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
auto dequeue2 =
ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
auto q3 =
ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
auto dequeue3 =
ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
auto q4 =
ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
auto enqueue4_2 =
ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue3[0]});
auto dequeue4 =
ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
GrapplerItem item;
TF_CHECK_OK(root.ToGraphDef(&item.graph));
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically());
const auto props1 = properties.GetOutputProperties("Dequeue1");
EXPECT_EQ(1, props1.size());
const OpInfo::TensorProperties& prop1 = props1[0];
EXPECT_EQ(DT_FLOAT, prop1.dtype());
EXPECT_FALSE(prop1.shape().unknown_rank());
EXPECT_EQ(2, prop1.shape().dim_size());
EXPECT_EQ(3, prop1.shape().dim(0).size());
EXPECT_EQ(7, prop1.shape().dim(1).size());
const auto props2 = properties.GetOutputProperties("Dequeue2");
EXPECT_EQ(1, props2.size());
const OpInfo::TensorProperties& prop2 = props2[0];
EXPECT_EQ(DT_FLOAT, prop2.dtype());
EXPECT_FALSE(prop2.shape().unknown_rank());
EXPECT_EQ(2, prop2.shape().dim_size());
EXPECT_EQ(3, prop2.shape().dim(0).size());
EXPECT_EQ(7, prop2.shape().dim(1).size());
// The dequeue3 op shape is unknown. The square2 op shape is known. Verify
// that we merge the 2 properly to determine the shape of the data coming out
// of the queue.
const auto props4 = properties.GetOutputProperties("Dequeue4");
EXPECT_EQ(1, props4.size());
const OpInfo::TensorProperties& prop4 = props4[0];
EXPECT_EQ(DT_FLOAT, prop4.dtype());
EXPECT_FALSE(prop4.shape().unknown_rank());
EXPECT_EQ(2, prop4.shape().dim_size());
EXPECT_EQ(3, prop4.shape().dim(0).size());
EXPECT_EQ(7, prop4.shape().dim(1).size());
}
}  // namespace
}  // namespace grappler
}  // namespace tensorflow


@@ -2042,6 +2042,7 @@ tf_kernel_library(

    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"@local_config_cuda//cuda:cusolver", "@local_config_cuda//cuda:cusolver",
], ],
) )
@@ -2322,7 +2323,9 @@ tf_kernel_library(

    prefix = "fft_ops",
    deps = MATH_DEPS + [
        "//tensorflow/core:spectral_ops_op_lib",
    ] + if_cuda([
"//tensorflow/core/platform/default/build_config:cufft_plugin",
]),
)

tf_kernel_library(

@@ -2626,7 +2629,9 @@ tf_kernel_library(

        "@libxsmm_archive//:xsmm_avx",
    ],
    "//conditions:default": [],
}) + if_cuda([
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
]),
)

tf_kernel_library(


@@ -19,6 +19,9 @@ limitations under the License.

#include "tensorflow/core/kernels/crop_and_resize_op.h"
#include <functional>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
@ -26,10 +29,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
@ -37,41 +43,67 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
using Callback = std::function<void()>;
namespace {

static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
                                           const Tensor& box_index,
                                           int* num_boxes) {
  if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
    *num_boxes = 0;
    return Status::OK();
  }
  // The shape of 'boxes' is [num_boxes, 4].
  if (boxes.dims() != 2) {
    return errors::InvalidArgument("boxes must be 2-D",
                                   boxes.shape().DebugString());
  }
  *num_boxes = boxes.dim_size(0);
  if (boxes.dim_size(1) != 4) {
    return errors::InvalidArgument("boxes must have 4 columns");
  }
  // The shape of 'box_index' is [num_boxes].
  if (box_index.dims() != 1) {
    return errors::InvalidArgument("box_index must be 1-D",
                                   box_index.shape().DebugString());
  }
  if (box_index.dim_size(0) != *num_boxes) {
    return errors::InvalidArgument("box_index has incompatible shape");
  }
  return Status::OK();
}

// Conditionally calls the compute callback if all values in box_index are in
// [0, batch_size), then calls done.
template <typename Device>
inline void RunIfBoxIndexIsValid(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, Callback compute, Callback done);

// Specialization of CheckValidBoxIndex for a CPUDevice.
template <>
inline void RunIfBoxIndexIsValid<CPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, Callback compute, Callback done) {
  const int num_boxes = box_index.dimension(0);
  for (int b = 0; b < num_boxes; ++b) {
    OP_REQUIRES_ASYNC(
        context, FastBoundsCheck(box_index(b), batch_size),
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
  }
  compute();
  done();
}

}  // namespace
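The CPU specialization runs the bounds check inline and then invokes compute followed by done; the point of the callback indirection is that the GPU path can defer both until the device-to-host copy of the check result completes, without blocking an executor thread. A minimal Python sketch of this validate-then-compute callback shape, purely illustrative and using no TensorFlow types:

def run_if_box_index_is_valid(box_index, batch_size, compute, done):
  # CPU-style path: validate synchronously, then run the kernel and signal
  # completion. An async device would schedule this after a host copy.
  for b in box_index:
    if not 0 <= b < batch_size:
      done()  # report completion of the errored step without computing
      raise ValueError("box_index has values outside [0, batch_size)")
  compute()
  done()

results = []
run_if_box_index_is_valid(
    box_index=[0, 1],
    batch_size=2,
    compute=lambda: results.append("cropped"),
    done=lambda: results.append("done"))
assert results == ["cropped", "done"]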
template <typename Device, typename T> template <typename Device, typename T>
class CropAndResizeOp : public OpKernel { class CropAndResizeOp : public AsyncOpKernel {
public: public:
explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { explicit CropAndResizeOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
string method; string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear", OP_REQUIRES(context, method == "bilinear",
@@ -80,69 +112,77 @@ class CropAndResizeOp : public AsyncOpKernel {

                                     &extrapolation_value_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'image' is [batch_size, image_height, image_width,
    // channels].
    const Tensor& image = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'crop_size' is [2].
    const Tensor& crop_size = context->input(3);
    // Validate input dimensions.
    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    const int depth = image.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);

    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
                      errors::InvalidArgument("crop_size must be 1-D",
                                              crop_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(
        context, crop_size.dim_size(0) == 2,
        errors::InvalidArgument("crop_size must have two elements",
                                crop_size.shape().DebugString()),
        done);

    // Copy and validate crop sizes.
    auto crop_size_vec = crop_size.vec<int32>();
    const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
    const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("crop dimensions must be positive"), done);
    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({num_boxes, crop_height, crop_width, depth}),
            &output),
        done);
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>(); auto compute_callback = [this, context, output]() {
typename TTypes<float, 2>::ConstTensor boxes_data = const Tensor& image = context->input(0);
boxes.tensor<float, 2>(); const Tensor& boxes = context->input(1);
typename TTypes<int32, 1>::ConstTensor box_ind_data = const Tensor& box_index = context->input(2);
box_ind.tensor<int32, 1>(); const bool status = functor::CropAndResize<Device, T>()(
typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>(); context->eigen_device<Device>(), image.tensor<T, 4>(),
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
extrapolation_value_, output->tensor<float, 4>());
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch); RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
bool status = functor::CropAndResize<Device, T>()( std::move(done));
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
} }
private: private:
@@ -155,10 +195,10 @@ template <typename T>
struct CropAndResize<CPUDevice, T> {
  bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  float extrapolation_value,
                  typename TTypes<float, 4>::Tensor crops) {
    const int batch_size = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);
@@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
      const float y2 = boxes(b, 2);
      const float x2 = boxes(b, 3);

      const int32 b_in = box_index(b);
      if (!FastBoundsCheck(b_in, batch_size)) {
        continue;
      }
@@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
    return true;
  }
};

}  // namespace functor
template <typename Device, typename T>
class CropAndResizeGradImageOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
  }
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'image_size' is [4].
    const Tensor& image_size = context->input(3);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
                      errors::InvalidArgument("image_size must be 1-D",
                                              image_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
                      errors::InvalidArgument("image_size must have 4 elements",
                                              image_size.shape().DebugString()),
                      done);
    auto image_size_vec = image_size.vec<int32>();
    const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
    const int image_height = internal::SubtleMustCopy(image_size_vec(1));
    const int image_width = internal::SubtleMustCopy(image_size_vec(2));
    const int depth = internal::SubtleMustCopy(image_size_vec(3));
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(3) == depth,
        errors::InvalidArgument("image_size and grads are incompatible"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({batch_size, image_height, image_width, depth}),
            &output),
        done);

    auto compute_callback = [context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
          output->tensor<T, 4>());
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed launch CropAndResizeBackpropImage kernel."));
      }
    };

    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }
};
@@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
  bool operator()(const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  typename TTypes<T, 4>::Tensor grads_image) {
    const int batch_size = grads_image.dimension(0);
    const int image_height = grads_image.dimension(1);
    const int image_width = grads_image.dimension(2);
@@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
      const float y2 = boxes(b, 2);
      const float x2 = boxes(b, 3);

      const int32 b_in = box_index(b);
      if (!FastBoundsCheck(b_in, batch_size)) {
        continue;
      }
@@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
    return true;
  }
};

}  // namespace functor
template <typename Device, typename T>
class CropAndResizeGradBoxesOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
  }
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(2);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(3);
    // The shape of 'image' is [batch_size, image_height, image_width, depth].
    const Tensor& image = context->input(1);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    const int depth = grads.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);

    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
                      errors::InvalidArgument("image, grads depth differ"),
                      done);

    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
        done);

    auto compute_callback = [context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& image = context->input(1);
      const Tensor& boxes = context->input(2);
      const Tensor& box_index = context->input(3);
      const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), output->tensor<float, 2>());
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed launch CropAndResizeBackpropBoxes kernel."));
      }
    };

    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }
};
@@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  typename TTypes<float, 2>::Tensor grads_boxes) {
    const int batch_size = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);
@@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
      const float y2 = boxes(b, 2);
      const float x2 = boxes(b, 3);

      const int32 b_in = box_index(b);
      if (!FastBoundsCheck(b_in, batch_size)) {
        continue;
      }
@@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
    return true;
  }
};

}  // namespace functor
#define REGISTER_KERNEL(T)                                \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T")     \
                              .HostMemory("crop_size"),   \
                          CropAndResizeOp<CPUDevice, T>); \
                                                          \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T"),    \
                          CropAndResizeGradBoxesOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);

#if GOOGLE_CUDA

// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor {
template <>
void CheckValidBoxIndexHelper<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndexHelper<GPUDevice>;
}  // namespace functor

namespace {

// Specialization of CheckValidBoxIndex for a GPUDevice.
template <>
inline void RunIfBoxIndexIsValid<GPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, Callback compute, Callback done) {
  const int num_boxes = box_index.dimension(0);
  if (num_boxes == 0) {
    compute();
    done();
    return;
  }

  Tensor isvalid_dev_tensor;
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_dev_tensor),
      done);
  typename TTypes<bool, 0>::Tensor isvalid_dev =
      isvalid_dev_tensor.tensor<bool, 0>();

  // Run the actual box check on the device.
  functor::CheckValidBoxIndexHelper<GPUDevice>()(
      context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);

  // Copy the result back to the host.
  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES_ASYNC(context, stream,
                    errors::Internal("No GPU stream available."), done);
  Tensor isvalid_host_tensor;
  // Use pinned host memory to avoid unnecessary synchronization.
  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  alloc_attr.set_gpu_compatible(true);
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_host_tensor, alloc_attr),
      done);
  typename TTypes<bool, 0>::Tensor isvalid_host =
      isvalid_host_tensor.tensor<bool, 0>();

  perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
                                                sizeof(bool));
  const bool status =
      stream
          ->ThenMemcpy(isvalid_host.data() /* destination */,
                       wrapped /* source */, sizeof(bool))
          .ok();
  OP_REQUIRES_ASYNC(
      context, status,
      errors::Internal("Failed to launch copy of isvalid from device to host."),
      done);

  auto wrapped_callback = [context, isvalid_host, compute, done]() {
    OP_REQUIRES_ASYNC(
        context, isvalid_host(),
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
    compute();
    done();
  };

  context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
      stream, wrapped_callback);
}

}  // namespace
#define REGISTER_KERNEL(T)                        \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")   \
                              .Device(DEVICE_GPU) \
@@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
};

template <typename Device>
struct CheckValidBoxIndexHelper {
  // Checks if all values in box_index are in [0, batch).
  void operator()(const Device& d,
                  typename TTypes<int32, 1>::ConstTensor box_index, int batch,
                  typename TTypes<bool, 0>::Tensor isvalid) {
    isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
  }
};
@@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS

template struct CheckValidBoxIndexHelper<GPUDevice>;

}  // namespace functor
}  // namespace tensorflow
@@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
  Status s = RunOpKernel();
  ASSERT_FALSE(s.ok());
  EXPECT_TRUE(
      StringPiece(s.ToString()).contains("box_index has incompatible shape"))
      << s;
}
@@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
  Status s = RunOpKernel();
  ASSERT_FALSE(s.ok());
  EXPECT_TRUE(StringPiece(s.ToString())
                  .contains("box_index has values outside [0, batch_size)"))
      << s;
}

// TODO(zhengxq, rmlarsen): Add a benchmark.

}  // namespace tensorflow
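The updated test expects the new `box_index` error string. As a rough Python-level illustration (a sketch, not part of this change), triggering that error through the public API would look something like this:

```python
import tensorflow as tf

# Hypothetical repro: box index 5 is outside [0, batch_size) for a batch of
# 2, so evaluating `crops` should fail with an OutOfRange error carrying the
# "box_index has values outside [0, batch_size)" message tested above.
image = tf.zeros([2, 64, 64, 3])             # batch_size = 2
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0]])  # one box covering the image
box_ind = tf.constant([5])                   # invalid batch index
crops = tf.image.crop_and_resize(image, boxes, box_ind, crop_size=[16, 16])

with tf.Session() as sess:
    sess.run(crops)  # expected to raise tf.errors.OutOfRangeError
```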

@@ -64,8 +64,8 @@ class LookupTableOp : public OpKernel {
      return ctx->status();
    }
    if (ctx->track_allocations()) {
      ctx->record_host_persistent_memory_allocation(
          container->MemoryUsed() + table_handle_.AllocatedBytes());
    }
    *ret = container;
    return Status::OK();
@@ -225,6 +225,15 @@ class HashTable : public InitializableLookupTable {
    return Status::OK();
  }

  int64 MemoryUsed() const override {
    if (table_) {
      const int64 num_elements = table_->size();
      return num_elements * (sizeof(K) + sizeof(V));
    } else {
      return 0;
    }
  }

 private:
  std::unique_ptr<std::unordered_map<K, V>> table_;
};

@@ -623,7 +623,17 @@ REGISTER_OP("QueueDequeueV2")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      if (c->num_outputs() == 1) {
        c->set_output(0, c->input_handle_shape(0));
      } else {
        // TODO(vrv): handle the case of multiple outputs.
        for (int i = 0; i < c->num_outputs(); ++i) {
          c->set_output(i, c->UnknownShape());
        }
      }
      return Status::OK();
    })
    .Doc(R"doc(
Dequeues a tuple of one or more tensors from the given queue.
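As a rough sketch of what this shape function enables (illustrative, not code from this change): once the shape refiner has merged the shapes seen at a queue's enqueue ops into the queue's handle shape, a single-component dequeue can pick up a static shape.

```python
import tensorflow as tf

# Sketch: the queue below has no explicit `shapes` argument, but every
# enqueue provides a [2, 3] tensor. With handle-shape propagation, the
# dequeue's static shape may be refined from <unknown> to (2, 3).
q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
enqueue_op = q.enqueue(tf.zeros([2, 3]))
x = q.dequeue()
print(x.get_shape())  # (2, 3) once the refinement applies; <unknown> before
```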

@@ -58,6 +58,22 @@ cc_library(
    ],
)

# Dummy stream executor cuda plugins.
cc_library(
    name = "cublas_plugin",
    srcs = [],
)

cc_library(
    name = "cufft_plugin",
    srcs = [],
)

cc_library(
    name = "cudnn_plugin",
    srcs = [],
)

# OSX framework for device driver access
cc_library(
    name = "IOKit",

@@ -137,16 +137,16 @@ which to operate must always be given explicitly. This is the reason why

## Module: reroute

* @{tf.contrib.graph_editor.swap_ts}
* @{tf.contrib.graph_editor.reroute_ts}
* @{tf.contrib.graph_editor.swap_inputs}
* @{tf.contrib.graph_editor.reroute_inputs}
* @{tf.contrib.graph_editor.swap_outputs}
* @{tf.contrib.graph_editor.reroute_outputs}
* @{tf.contrib.graph_editor.swap_ios}
* @{tf.contrib.graph_editor.reroute_ios}
* @{tf.contrib.graph_editor.remove_control_inputs}
* @{tf.contrib.graph_editor.add_control_inputs}

## Module: edit

@@ -21,7 +21,7 @@ Subclasses of `LinearOperator` provide access to common methods on a

* @{tf.contrib.linalg.LinearOperatorDiag}
* @{tf.contrib.linalg.LinearOperatorIdentity}
* @{tf.contrib.linalg.LinearOperatorScaledIdentity}
* @{tf.contrib.linalg.LinearOperatorFullMatrix}
* @{tf.contrib.linalg.LinearOperatorTriL}
* @{tf.contrib.linalg.LinearOperatorUDVHUpdate}

@@ -13,8 +13,8 @@ of samples in the batch and `d1` ... `dN` are the remaining dimensions.

It is common, when training with multiple loss functions, to adjust the relative
strengths of individual losses. This is performed by rescaling the losses via
a `weight` parameter passed to the loss functions. For example, if we were
training with both log_loss and mean_square_error, and we wished that the
log_loss penalty be twice as severe as the mean_square_error, we would
implement this as:

```python
@@ -22,7 +22,7 @@ implement this as:
  tf.contrib.losses.log(predictions, labels, weight=2.0)

  # Uses default weight of 1.0
  tf.contrib.losses.mean_square_error(predictions, labels)

  # All the losses are collected into the `GraphKeys.LOSSES` collection.
  losses = tf.get_collection(tf.GraphKeys.LOSSES)
@@ -74,7 +74,7 @@ these predictions.
  predictions = MyModelPredictions(images)

  weight = tf.cast(tf.greater(depths, 0), tf.float32)
  loss = tf.contrib.losses.mean_square_error(predictions, depths, weight)
```
Note that when using weights for the losses, the final average is computed
@@ -100,7 +100,7 @@ weighted average over the individual prediction errors:
  weight = MyComplicatedWeightingFunction(labels)
  weight = tf.div(weight, tf.size(weight))
  loss = tf.contrib.losses.mean_square_error(predictions, depths, weight)
```

@{tf.contrib.losses.absolute_difference}
@@ -118,9 +118,4 @@ weighted average over the individual prediction errors:
@{tf.contrib.losses.softmax_cross_entropy}
@{tf.contrib.losses.sparse_softmax_cross_entropy}

@@ -278,7 +278,7 @@ Then, the code creates a `DNNClassifier` model using the following arguments:

The `tf.contrib.learn` API uses input functions, which create the TensorFlow
operations that generate data for the model. In this case, the data is small
enough that it can be stored in @{tf.constant$TensorFlow constants}. The
following code produces the simplest possible input pipeline:
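A hedged sketch of such an input function (names and data are illustrative, not the tutorial's actual code):

```python
import numpy as np
import tensorflow as tf

# Illustrative only: a tf.contrib.learn-style input function that keeps a
# small in-memory dataset in TensorFlow constants.
def input_fn():
  features = tf.constant(np.random.rand(120, 4), dtype=tf.float32)
  labels = tf.constant(np.random.randint(0, 3, size=120), dtype=tf.int32)
  return features, labels
```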

@@ -211,15 +211,20 @@ two files are available to the JVM:

* the downloaded `.jar` file
* the extracted JNI library

For example, the following command line executes the `HelloTF` program on Linux
and Mac OS X:

<pre><b>java -cp libtensorflow-1.1.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>

And the following command line executes the `HelloTF` program on Windows:

<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>

If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
outputs something else, check
[Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow) for
possible solutions.

### Advanced Example

@@ -1,17 +1,17 @@
# Benchmarks

## Overview

A selection of image classification models were tested across multiple platforms
to create a point of reference for the TensorFlow community. The methodology,
links to the benchmark scripts, and commands to reproduce the results are in the
[Appendix](#appendix).

## Results for image classification models

InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)), ResNet-50
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), ResNet-152
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), VGG16
([arXiv:1409.1556](https://arxiv.org/abs/1409.1556)), and
[AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
were tested using the [ImageNet](http://www.image-net.org/) data set. Tests were
@@ -27,32 +27,32 @@ input pipeline and the underlying disk I/O are saturating the compute units.

### Training with NVIDIA® DGX-1™ (NVIDIA® Tesla® P100)

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:80%" src="../images/perf_summary_p100_single_server.png">
</div>

Details and additional results are in the [Details for NVIDIA® DGX-1™ (NVIDIA®
Tesla® P100)](#details_for_nvidia_dgx-1tm_nvidia_tesla_p100) section.

### Training with NVIDIA® Tesla® K80

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:80%" src="../images/perf_summary_k80_single_server.png">
</div>

Details and additional results are in the [Details for Google Compute Engine
(NVIDIA® Tesla® K80)](#details_for_google_compute_engine_nvidia_tesla_k80) and
[Details for Amazon EC2 (NVIDIA® Tesla®
K80)](#details_for_amazon_ec2_nvidia_tesla_k80) sections.

### Distributed training with NVIDIA® Tesla® K80

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:80%" src="../images/perf_summary_k80_aws_distributed.png">
</div>

Details and additional results are in the [Details for Amazon EC2 Distributed
(NVIDIA® Tesla® K80)](#details_for_amazon_ec2_distributed_nvidia_tesla_k80)
section.

### Compare synthetic with real data training
@@ -82,12 +82,15 @@ section.

* **TensorFlow GitHub hash:** b1e174e
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
  //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** Local SSD
* **DataSet:** ImageNet

Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3, ResNet-50,
ResNet-152, and VGG16 were tested with a batch size of 32. Those results are in
the *other results* section.

Options            | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
------------------ | ----------- | --------- | ---------- | ------- | -----
Batch size per GPU | 64          | 64        | 64         | 512     | 64
Optimizer          | sgd         | sgd       | sgd        | sgd     | sgd
@@ -104,10 +107,8 @@ VGG16 | replicated (with NCCL) | n/a

### Results

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:80%" src="../images/perf_summary_p100_single_server.png">
</div>

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
@@ -136,6 +137,28 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16

Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to it maxing out the input pipeline.

### Other Results

The results below are all with a batch size of 32.

**Training synthetic data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
1    | 128         | 210       | 85.3       | 124
2    | 259         | 412       | 166        | 241
4    | 520         | 827       | 330        | 470
8    | 995         | 1623      | 643        | 738

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
1    | 130         | 208       | 85.0       | 124
2    | 257         | 403       | 163        | 221
4    | 507         | 814       | 325        | 401
8    | 966         | 1525      | 641        | 619

## Details for Google Compute Engine (NVIDIA® Tesla® K80)

### Environment
@@ -156,7 +179,7 @@ addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
tested with a batch size of 32. Those results are in the *other results*
section.

Options            | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
------------------ | ----------- | --------- | ---------- | ------- | -----
Batch size per GPU | 64          | 64        | 32         | 512     | 32
Optimizer          | sgd         | sgd       | sgd        | sgd     | sgd
@@ -184,10 +207,10 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1    | 30.6        | 56.7      | 20.7       | 639     | 30.2
2    | 58.4        | 107       | 39.0       | 1136    | 55.5
4    | 115         | 211       | 77.3       | 2067    | 106
8    | 225         | 422       | 151        | 4056    | 213

### Other Results
@@ -204,10 +227,10 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1    | 29.5                        | 53.6
2    | 55.4                        | 102
4    | 110                         | 201
8    | 216                         | 387
## Details for Amazon EC2 (NVIDIA® Tesla® K80)
@@ -230,7 +253,7 @@ addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
tested with a batch size of 32. Those results are in the *other results*
section.

Options            | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
------------------ | ----------- | --------- | ---------- | ------- | -----
Batch size per GPU | 64          | 64        | 32         | 512     | 32
Optimizer          | sgd         | sgd       | sgd        | sgd     | sgd
@@ -289,7 +312,7 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1    | 30.0                        | 53.6
2    | 57.5                        | 102
4    | 113                         | 202
8    | 212                         | 379
@@ -313,7 +336,7 @@ addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
tested with a batch size of 32. Those results are in the *other results*
section.

Options            | InceptionV3 | ResNet-50 | ResNet-152
------------------ | ----------- | --------- | ----------
Batch size per GPU | 64          | 64        | 32
Optimizer          | sgd         | sgd       | sgd
@@ -337,7 +360,7 @@ used with the following exceptions:

### Results

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:80%" src="../images/perf_summary_k80_aws_distributed.png">
</div>

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
@@ -374,34 +397,37 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)

### Executing benchmark tests

The [benchmark code](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was created to be used for benchmarking TensorFlow as well as used as a tool to
test hardware platforms. Techniques used in the benchmark scripts are detailed
in @{$performance_models$High-Performance Models}.

There are two ways to execute the benchmark code:

1.  Execute [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
    directly.
2.  Utilize the [scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/main.py)
    that help pick the correct config for each platform and then execute
    `tf_cnn_benchmarks.py`.

The wrapper is suggested as a starting point. Then investigate the variety of
options available in `tf_cnn_benchmarks.py`. Below are a couple examples of
using the wrapper.

**Single Server**
This example illustrates training ResNet-50 on a single instance with 8 GPUs.
The `system` flag is used to determine the optimal configuration. The
supported values are gce, aws, and dgx1. If `system` is not passed, the best
config for the most widely available hardware is used.

```bash
python main.py --model=resnet50 --num_gpus=8

python main.py --system=aws --model=resnet50 --num_gpus=8
```

**Distributed**
This example illustrates training ResNet-50 on 2 hosts, e.g. host_0 (10.0.0.1)
and host_1 (10.0.0.2), with 8 GPUs each on AWS (Amazon EC2).

```bash
# Run the following commands on host_0 (10.0.0.1):

@@ -2,11 +2,19 @@

Performance is often a significant issue when training a machine learning
model. This section explains various ways to optimize performance. Start
your investigation with the @{$performance_guide$Performance Guide} and then go
deeper with techniques detailed in @{$performance_models$High-Performance Models}:

  * @{$performance_guide$Performance Guide}, which contains a collection of best
    practices for optimizing your TensorFlow code.

  * @{$performance_models$High-Performance Models}, which contains a collection
    of advanced techniques to build highly scalable models targeting different
    system types and network topologies.

  * @{$benchmarks$Benchmarks}, which contains a collection of benchmark
    results.

XLA (Accelerated Linear Algebra) is an experimental compiler for linear
algebra that optimizes TensorFlow computations. The following guides explore
XLA:

@@ -1,4 +1,8 @@
performance_guide.md
performance_models.md
benchmarks.md
quantization.md
>>>
xla/index.md
xla/broadcasting.md
xla/developing_new_backend.md
@@ -6,4 +10,3 @@ xla/jit.md
xla/operation_semantics.md
xla/shapes.md
xla/tfcompile.md

@@ -1,8 +1,10 @@
# Performance Guide

This guide contains a collection of best practices for optimizing your
TensorFlow code. The best practices apply to both new and experienced
TensorFlow users. As a complement to the best practices in this document, the
@{$performance_models$High-Performance Models} document links to example code
and details for creating models that scale on a variety of hardware.

## Best Practices

While optimizing implementations of different types of models can be different,
@@ -73,7 +75,7 @@ Unless for a special circumstance or for example code, do not feed data
into the session from Python variables, e.g. `dictionary`.

```python
# Using feed_dict often results in suboptimal performance when using large inputs.
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
```
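A hedged sketch of the preferred alternative (file name and feature spec are assumed for illustration): read and batch input with TensorFlow ops so data never round-trips through Python on each step.

```python
import tensorflow as tf

# Illustrative reader pipeline: parse TFRecords and batch them with queues.
filename_queue = tf.train.string_input_producer(['train.tfrecords'])
reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue)
example = tf.parse_single_example(
    serialized,
    features={'x': tf.FixedLenFeature([784], tf.float32),
              'y': tf.FixedLenFeature([], tf.int64)})
x_batch, y_batch = tf.train.shuffle_batch(
    [example['x'], example['y']], batch_size=100,
    capacity=10000, min_after_dequeue=1000)
```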
@@ -141,3 +143,4 @@ bn = tf.contrib.layers.batch_norm(
The non-fused batch norm does computations using several individual Ops. Fused
batch norm combines the individual operations into a single kernel, which runs
faster.

@@ -1,155 +1,109 @@
# High-Performance Models

This document and accompanying
[scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
detail how to build highly scalable models that target a variety of system types
and network topologies. The techniques in this document utilize some low-level
TensorFlow Python primitives. In the future, many of these techniques will be
incorporated into high-level APIs.

## Input Pipeline

The @{$performance_guide$Performance Guide} explains how to identify possible
input pipeline issues and best practices. We found that using @{tf.FIFOQueue}
and @{tf.train.queue_runner} could not saturate multiple current generation GPUs
when using large inputs and processing with higher samples per second, such
as training ImageNet with [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf).
This is due to the use of Python threads as its underlying implementation, and
the overhead of Python threads is too large.

Another approach, which we have implemented in the
[scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks),
is to build an input pipeline using the native parallelism in TensorFlow. Our
implementation is made up of 3 stages:

*   I/O reads: Choose and read image files from disk.
*   Image Processing: Decode image records into images, preprocess, and
    organize into mini-batches.
*   CPU-to-GPU Data Transfer: Transfer images from CPU to GPU.

The dominant part of each stage is executed in parallel with the other stages
using `data_flow_ops.StagingArea`. `StagingArea` is a queue-like operator
similar to @{tf.FIFOQueue}. The difference is that `StagingArea` offers simpler
functionality and can be executed on both CPU and GPU in parallel with other
stages. Breaking the input pipeline into 3 stages that operate independently in
parallel is scalable and takes full advantage of large multi-core environments.
The rest of this section details the stages and the use of
`data_flow_ops.StagingArea`, starting with a short sketch of `StagingArea`
below.
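A minimal sketch of `StagingArea` usage (shapes, dtypes, and the run pattern are illustrative, not taken from the scripts):

```python
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

# Stage one preprocessed batch so the producing and consuming stages can
# overlap across session.run() calls.
images = tf.zeros([32, 224, 224, 3])  # stand-in for a preprocessed batch
labels = tf.zeros([32], dtype=tf.int32)

area = data_flow_ops.StagingArea(
    dtypes=[tf.float32, tf.int32],
    shapes=[[32, 224, 224, 3], [32]])
stage_op = area.put([images, labels])      # executed by the producer stage
staged_images, staged_labels = area.get()  # consumed by the next stage

# Typical pattern: run stage_op once to warm the pipeline, then run the
# training op together with stage_op so the next batch is staged while the
# current one is consumed.
```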
### Parallelize I/O Reads

`data_flow_ops.RecordInput` is used to parallelize reading from disk. Given a
list of input files representing TFRecords, `RecordInput` continuously reads
records using background threads. The records are placed into its own large
internal pool, and once it has loaded at least half of its capacity, it
produces output tensors.

This op has its own internal threads that are dominated by I/O time and consume
minimal CPU, which allows it to run smoothly in parallel with the rest of the
model.
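As a rough sketch, the read stage can be set up as follows. This is a minimal
example, and the file pattern and tuning values (`parallelism`, `buffer_size`)
are illustrative assumptions rather than recommendations:

```python
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

batch_size = 256
record_input = data_flow_ops.RecordInput(
    file_pattern='/data/train-*-of-*',  # hypothetical TFRecord shards
    parallelism=64,      # background reader threads
    buffer_size=10000,   # internal record pool; output starts at half full
    batch_size=batch_size)
records = record_input.get_yield_op()       # serialized records, one batch
records = tf.split(records, batch_size, 0)  # 256 independent record tensors
records = [tf.reshape(record, []) for record in records]
```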
### Parallelize Image Processing

After images are read from `RecordInput` they are passed as tensors to the
image processing pipeline. To make the image processing pipeline easier to
explain, assume that the input pipeline is targeting 8 GPUs with a batch size
of 256 (32 per GPU).

256 records are read and processed individually in parallel. This starts with
256 independent `RecordInput` read ops in the graph. Each read op is followed
by an identical set of ops for image preprocessing that are considered
independent and executed in parallel. The image preprocessing ops include
operations such as image decoding, distortion, and resizing.

Once the images are through preprocessing, they are concatenated together into
8 tensors, each with a batch size of 32. Rather than using @{tf.concat} for
this purpose, which is implemented as a single op that waits for all the inputs
to be ready before concatenating them together, @{tf.parallel_stack} is used.
Because the inputs are produced in parallel, `tf.concat` would wait on a long
tail of stragglers and then become memory-bound as all input tensors compete
for memory bandwidth. @{tf.parallel_stack} instead allocates an uninitialized
tensor as an output, and each input tensor is written to its designated portion
of the output tensor as soon as the input is available.

When all the input tensors are finished, the output tensor is passed along in
the graph. This effectively hides all the memory latency with the long tail of
producing all the input tensors.
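A sketch of the batching step, assuming a hypothetical `preprocess()` helper
standing in for the decode/distort/resize ops and the `records` list produced
above:

```python
import tensorflow as tf

def preprocess(record):
  # Stand-in for the real decode/distort/resize pipeline.
  image = tf.image.decode_jpeg(record, channels=3)
  image = tf.image.resize_images(image, [224, 224])  # illustrative size
  return image

images = [preprocess(record) for record in records]
# 8 groups of 32 images; each group becomes one [32, 224, 224, 3] batch, and
# each image is written into the batch tensor as soon as it is ready.
batches = [tf.parallel_stack(images[i * 32:(i + 1) * 32]) for i in range(8)]
```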
### Parallelize CPU-to-GPU Data Transfer

Continuing with the assumption that the target is 8 GPUs with a batch size of
256 (32 per GPU), once the input images are processed and concatenated together
by the CPU, we have 8 tensors, each with a batch size of 32.

TensorFlow enables tensors from one device to be used on any other device
directly. TensorFlow inserts implicit copies to make the tensors available on
any device where they are used. The runtime schedules the copy between devices
to run before the tensors are actually used. However, if the copy cannot finish
in time, the computation that needs those tensors will stall, resulting in
decreased performance.

In this implementation, `data_flow_ops.StagingArea` is used to explicitly
schedule the copy in parallel. The end result is that when computation starts
on the GPU, all the tensors are already available.
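A minimal sketch of staging one batch onto a GPU, assuming hypothetical
`cpu_images` and `cpu_labels` tensors produced by the CPU stages:

```python
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

with tf.device('/gpu:0'):
  gpu_stage = data_flow_ops.StagingArea(
      dtypes=[tf.float32, tf.int32],
      shapes=[[32, 224, 224, 3], [32]])
  # put() enqueues the (implicitly copied) batch; running it one step ahead
  # of the compute lets the copy overlap with the previous step's math.
  stage_put = gpu_stage.put([cpu_images, cpu_labels])
  gpu_images, gpu_labels = gpu_stage.get()  # already resident on the GPU
```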
### Software Pipelining

With all the stages capable of being driven by different processors,
`data_flow_ops.StagingArea` is used between them so they run in parallel.
`StagingArea` is a queue-like operator similar to @{tf.FIFOQueue} that offers
simpler functionality and can be executed on both CPU and GPU.

Before the model starts running all the stages, the input pipeline stages are
warmed up to prime the staging buffers in between with one set of data. During
each run step that follows, one set of data is read from the staging buffers at
the beginning of each stage, and one set is pushed at the end.

For example, if there are three stages, A, B and C, then there are two staging
areas in between, S1 and S2. During the warm up, we run:
```
Warm up:
Step 1: A0
Step 2: A1  B0

Actual execution:
Step 3: A2  B1  C0
Step 4: A3  B2  C1
Step 5: A4  B3  C2
```
After the warm up, S1 and S2 each have one set of data in them. For each step of
the actual execution, one set of data is consumed from each staging area, and
one set is added to each.

Benefits of using this scheme:

* All stages are non-blocking, since the staging areas always have one set of
  data after the warm up.
* Each stage can run in parallel since they can all start immediately.
* The staging buffers have a fixed memory overhead; they will have at most one
  extra set of data.
* Only a single `session.run()` call is needed to run all stages of the step,
  which makes profiling and debugging much easier (see the sketch below).
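A runnable sketch of the warm-up-then-run pattern with two staging areas; the
toy stages, shapes, and step count are illustrative assumptions:

```python
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

# Two staging areas S1 and S2 between three toy stages A, B and C.
s1 = data_flow_ops.StagingArea(dtypes=[tf.float32], shapes=[[32]])
s2 = data_flow_ops.StagingArea(dtypes=[tf.float32], shapes=[[32]])
stage1_put = s1.put([tf.random_normal([32])])  # stage A: produce data
stage2_put = s2.put(s1.get())                  # stage B: move S1 -> S2
train_op = tf.reduce_sum(s2.get())             # stage C: consume from S2

with tf.Session() as sess:
  # Warm up: prime the buffers so each holds one set of data.
  sess.run(stage1_put)                 # step 1: A0
  sess.run([stage1_put, stage2_put])   # step 2: A1, B0
  # Steady state: a single run call advances every stage by one set.
  for _ in range(100):
    sess.run([train_op, stage1_put, stage2_put])
```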
## Best Practices in Building High-Performance Models

Collected below are a couple of additional best practices that can improve
performance and increase the flexibility of models.
### Build the model with both NHWC and NCHW

Most TensorFlow operations used by a CNN support both NHWC and NCHW data
formats. On GPU, NCHW is faster, but on CPU, NHWC is sometimes faster.

Building a model to support both data formats keeps the model flexible and
capable of operating optimally regardless of platform. The benchmark script was
written to support both NCHW and NHWC. NCHW should always be used when training
with GPUs. A flexible model can be trained on GPUs using NCHW and then used for
inference on CPU using NHWC, with the weights obtained from training.
### Use Fused Batch-Normalization

The default batch-normalization in TensorFlow is implemented as composite
operations. This is very general, but often leads to suboptimal performance. An
alternative is to use fused batch-normalization, which often has much better
performance on GPU. Below is an example of using @{tf.contrib.layers.batch_norm}
to implement fused batch-normalization.

```python
bn = tf.contrib.layers.batch_norm(
    input_layer, fused=True, data_format='NCHW',
    scope=scope)
```
## Variable Distribution and Gradient Aggregation

During training, training variable values are updated using aggregated
gradients and deltas. In the benchmark script, we demonstrate that with the
flexible and general-purpose TensorFlow primitives, a diverse range of
high-performance distribution and aggregation schemes can be built.

Three examples of variable distribution and aggregation are included in the
script:

* `parameter_server`: each replica of the training model reads the variables
  from a parameter server and updates the variables independently. When each
  model needs the variables, they are copied over through the standard implicit
  copies added by the TensorFlow runtime. The example
  [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
  illustrates using this method for local training, distributed synchronous
  training, and distributed asynchronous training.
* `replicated`: places an identical copy of each training variable on each GPU.
  The forward and backward computation can start immediately as the variable
  data is immediately available. Gradients are accumulated across all GPUs, and
  the aggregated total is applied to each GPU's copy of the variables to keep
  them in sync.
* `distributed_replicated`: places an identical copy of the training parameters
  on each GPU along with a master copy on the parameter servers. The forward
  and backward computation can start immediately as the variable data is
  immediately available. Gradients are accumulated across all GPUs on each
  server and then the per-server aggregated gradients are applied to the master
  copy. After all workers do this, each worker updates its copy of the variable
  from the master copy.

Below are additional details about each approach.
### Parameter Server Variables

The most common way trainable variables are managed in TensorFlow models is
parameter server mode.

In a distributed system, each worker process runs the same model, and parameter
server processes own the master copies of the variables. When a worker needs a
variable from a parameter server, it refers to it directly. The TensorFlow
runtime adds implicit copies to the graph to make the variable value available
on the computation device that needs it. When a gradient is computed on a
worker, it is sent to the parameter server that owns the particular variable,
and the corresponding optimizer is used to update the variable.

There are some techniques to improve throughput:

* The variables are spread among parameter servers based on their size, for
  load balancing.
* When each worker has multiple GPUs, gradients are accumulated across the
  GPUs and a single aggregated gradient is sent to the parameter server. This
  reduces the network bandwidth and the amount of work done by the parameter
  servers (see the sketch below).
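A sketch of that accumulation step, assuming `tower_grads` is a list with one
`(gradient, variable)` list per GPU, as returned by
`optimizer.compute_gradients()` on each tower:

```python
import tensorflow as tf

def aggregate_gradients(tower_grads):
  # Average each variable's gradients across all towers so that a single
  # aggregated gradient is sent to the parameter server.
  aggregated = []
  for grads_and_var in zip(*tower_grads):
    grads = [g for g, _ in grads_and_var]
    var = grads_and_var[0][1]  # the variable is shared across towers
    aggregated.append((tf.add_n(grads) / len(grads), var))
  return aggregated
```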
For coordinating between workers, a very common mode is async updates, where
each worker updates the master copy of the variables without synchronizing with
other workers. In our model, we demonstrate that it is fairly easy to introduce
synchronization across workers so updates for all workers are finished in one
step before the next step can start.

The parameter server method can also be used for local training. In this case,
instead of spreading the master copies of variables across parameter servers,
they are either on the CPU or spread across the available GPUs.

Due to the simple nature of this setup, this architecture has gained a lot of
popularity within the community.

This mode can be used in the script by passing
`--variable_update=parameter_server`.

<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" alt="parameter_server mode in distributed training"
   src="../images/perf_parameter_server_mode_doc.png">
</div>
### Replicated Variables

In this design, each GPU on the server has its own copy of each variable.
Gradients are aggregated across the devices and the fully aggregated gradient
is then applied to each local copy.

Gradient aggregation across the server can be done in different ways:

* Using standard TensorFlow operations to accumulate the total on a single
  device (CPU or GPU) and then copy it back to all GPUs.
* Using NVIDIA® NCCL, described below in the NCCL section.

This mode can be used in the script by passing `--variable_update=replicated`.
### Replicated Variables in Distributed Training

The replicated method for variables can be extended to distributed training.
One way to do this, like the replicated mode, is to aggregate the gradients
fully across the cluster and apply them to each local copy of the variable.
This may be shown in a future version of the scripts; the scripts do present a
different variation, described here.
In this mode, in addition to each GPU's copy of the variables, a master copy is
stored on the parameter servers, and training can start immediately using the
local copies of the variables.

As the gradients of the weights become available, they are sent back to the
parameter servers and all local copies are updated:

1. All the gradients from the GPUs on the same worker are aggregated together.
2. Aggregated gradients from each worker are sent to the parameter server that
   owns the variable, where the specified optimizer is used to update the
   master copy of the variable.
3. Each worker updates its local copy of the variable from the master. In the
   example model, this is done with a cross-replica barrier that waits for all
   the workers to finish updating the variables, and fetches the new variable
   only after the barrier has been released by all replicas. Once the copy
   finishes for all variables, this marks the end of a training step, and a new
   step can start.

Although this sounds similar to the standard use of parameter servers, the
performance is often better. This is largely due to the fact that the
computation can happen without any delay, and much of the copy latency of early
gradients can be hidden by later computation layers.

This mode can be used in the script by passing
`--variable_update=distributed_replicated`.

<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" alt="distributed_replicated mode"
   src="../images/perf_distributed_replicated_mode_doc.png">
</div>
#### NCCL

In order to broadcast variables and aggregate gradients across different GPUs
within the same host machine, we can use the default TensorFlow implicit copy
mechanism.

However, we can instead use the optional NCCL (@{tf.contrib.nccl}) support.
NCCL is an NVIDIA® library that can efficiently broadcast and aggregate data
across different GPUs. It schedules a cooperating kernel on each GPU that knows
how to best utilize the underlying hardware topology; this kernel uses a single
SM of the GPU.

In our experiments, we found that although NCCL often leads to much faster data
aggregation by itself, it doesn't necessarily lead to faster training. Our
hypothesis is that the implicit copies are essentially free, since they go to
the copy engine on the GPU, as long as their latency can be hidden by the main
computation itself. Although NCCL can transfer data faster, it takes one SM
away and adds more pressure to the underlying L2 cache. Our results show that
for 8 GPUs, NCCL often leads to better performance. However, for fewer GPUs,
the implicit copies often perform better.
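A minimal sketch of aggregating per-GPU tensors with @{tf.contrib.nccl}; the
random tensors below are placeholders for real per-tower gradients:

```python
import tensorflow as tf
from tensorflow.contrib import nccl

num_gpus = 2  # illustrative
tower_grads = []
for i in range(num_gpus):
  with tf.device('/gpu:%d' % i):
    tower_grads.append(tf.random_normal([1000]))  # stand-in gradient

# all_sum returns one tensor per input; each output holds the sum of all
# inputs and lives on the same GPU as its corresponding input.
summed_grads = nccl.all_sum(tower_grads)
```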
#### Staged Variables

We further introduce a staged-variable mode, where we use staging areas for
both the variable reads and their updates. Similar to software pipelining of
the input pipeline, this can hide the data copy latency. If the computation
takes longer than the copy and aggregation, the copy itself becomes essentially
free.

The downside is that all the weights read are from the previous training step,
so it is a different algorithm from SGD, but it is possible to improve its
convergence by adjusting the learning rate and other hyperparameters.
## Conclusions

In this high-performance model, we have presented a number of options for
building high-performance models in TensorFlow. Due to TensorFlow's flexible
design, advanced features like these often require no system-level changes and
can be largely achieved through model-level changes.

We do not claim to know which combination works best for a particular model;
that should be left to the engineers who build the model and the training
system. Many of the ingredients of the high-performance model will find their
way into high-level primitives that become transparent to users. However, we
have shown that advanced users can easily tune and modify the underlying model
behavior using low-level primitives. This can be very useful when improving
performance for particular system setups and model configurations.

View File

@@ -94,13 +94,15 @@ class Estimator(object):
* Args:
  * `features`: This is the first item returned from the `input_fn`
    passed to `train`, `evaluate`, and `predict`. This should be a
    single `Tensor` or `dict` of same.
  * `labels`: This is the second item returned from the `input_fn`
    passed to `train`, `evaluate`, and `predict`. This should be a
    single `Tensor` or `dict` of same (for multi-head models). If
    mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
    the `model_fn`'s signature does not accept `mode`, the
    `model_fn` must still be able to handle `labels=None`.
  * `mode`: Optional. Specifies if this is training, evaluation or
    prediction. See `ModeKeys`.
  * `params`: Optional `dict` of hyperparameters. Will receive what

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for tensor shape inference."""
from __future__ import absolute_import
from __future__ import division
@@ -31,8 +30,8 @@ class Dimension(object):
      self._value = None
    else:
      self._value = int(value)
      if (not isinstance(value, compat.bytes_or_text_types) and
          self._value != value):
        raise ValueError("Ambiguous dimension: %s" % value)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)
@@ -89,9 +88,8 @@ class Dimension(object):
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.
@@ -104,8 +102,8 @@ class Dimension(object):
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" % (self,
                                                                    other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.
@@ -385,18 +383,17 @@ class TensorShape(object):
  `Tensor`. It may be one of the following:

  * *Fully-known shape:* has a known number of dimensions and a known size
    for each dimension. e.g. `TensorShape([16, 256])`
  * *Partially-known shape:* has a known number of dimensions, and an unknown
    size for one or more dimension. e.g. `TensorShape([None, 256])`
  * *Unknown shape:* has an unknown number of dimensions, and an unknown
    size in all dimensions. e.g. `TensorShape(None)`

  If a tensor is produced by an operation of type `"Foo"`, its shape
  may be inferred if there is a registered shape function for
  `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
  for details of shape functions and how to register them. Alternatively,
  the shape may be set explicitly using @{tf.Tensor.set_shape}.
  """

  def __init__(self, dims):
@@ -414,7 +411,7 @@ class TensorShape(object):
      self._dims = None
    elif isinstance(dims, compat.bytes_or_text_types):
      raise TypeError("A string has ambiguous TensorShape, please wrap in a "
                      "list or convert to an int: %s" % dims)
    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
      if dims.unknown_rank:
        self._dims = None
@@ -422,7 +419,8 @@ class TensorShape(object):
        self._dims = [
            # Protos store variable-size dimensions as -1
            as_dimension(dim.size if dim.size != -1 else None)
            for dim in dims.dim
        ]
    elif isinstance(dims, TensorShape):
      self._dims = dims.dims
    else:
@@ -519,7 +517,7 @@ class TensorShape(object):
        # suffixes of otherwise unknown shapes.
        return unknown_shape()
      else:
        return unknown_shape(ndims=stop - start)
    else:
      return Dimension(None)
@@ -560,8 +558,7 @@ class TensorShape(object):
          new_dims.append(dim.merge_with(other[i]))
        return TensorShape(new_dims)
      except ValueError:
        raise ValueError("Shapes %s and %s are not compatible" % (self, other))

  def concatenate(self, other):
    """Returns the concatenation of the dimension in `self` and `other`.
@@ -599,8 +596,8 @@ class TensorShape(object):
    other = as_shape(other)
    if self.ndims is not None and other.ndims is not None:
      if self.ndims != other.ndims:
        raise ValueError("Shapes %s and %s must have the same rank" % (self,
                                                                       other))

  def assert_has_rank(self, rank):
    """Raises an exception if `self` is not compatible with the given `rank`.
@@ -736,8 +733,8 @@ class TensorShape(object):
  def is_fully_defined(self):
    """Returns True iff `self` is fully defined in every dimension."""
    return (self._dims is not None and all(dim.value is not None
                                           for dim in self._dims))

  def assert_is_fully_defined(self):
    """Raises an exception if `self` is not fully defined in every dimension.
@@ -767,9 +764,10 @@ class TensorShape(object):
      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
    else:
      return tensor_shape_pb2.TensorShapeProto(dim=[
          tensor_shape_pb2.TensorShapeProto.Dim(size=-1
                                                if d.value is None else d.value)
          for d in self._dims
      ])

  def __eq__(self, other):
    """Returns True if `self` is equivalent to `other`."""

View File

@@ -41,6 +41,180 @@ cuda_py_test(
    ],
)
cuda_py_test(
name = "beta_test",
size = "small",
srcs = ["beta_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "bernoulli_test",
size = "small",
srcs = ["bernoulli_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "categorical_test",
size = "small",
srcs = ["categorical_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
)
cuda_py_test(
name = "dirichlet_test",
size = "small",
srcs = ["dirichlet_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "dirichlet_multinomial_test",
size = "medium",
srcs = ["dirichlet_multinomial_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "exponential_test",
srcs = ["exponential_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "gamma_test",
srcs = ["gamma_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "laplace_test",
srcs = ["laplace_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "multinomial_test",
srcs = ["multinomial_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "student_t_test",
size = "small",
srcs = ["student_t_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
tags = ["nomsan"], # disable to avoid false positives from scipy.
)
cuda_py_test(
name = "uniform_test",
size = "small",
srcs = ["uniform_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
],
)
cuda_py_test(
    name = "normal_test",
    size = "medium",

View File

@@ -18,15 +18,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")


def make_bernoulli(batch_shape, dtype=dtypes.int32):
@@ -54,13 +69,16 @@ class BernoulliTest(test.TestCase):
    with self.test_session():
      self.assertAllClose(logits, dist.logits.eval())

    if not special:
      return

    with self.test_session():
      self.assertAllClose(special.expit(logits), dist.probs.eval())

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    with self.test_session():
      self.assertAllClose(special.logit(p), dist.logits.eval())

  def testInvalidP(self):
    invalid_ps = [1.01, 2.]
@@ -160,7 +178,9 @@ class BernoulliTest(test.TestCase):
  def testPmfWithP(self):
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(probs=p)
    if not special:
      return
    self._testPmf(logits=special.logit(p))

  def testBroadcasting(self):
    with self.test_session():
View File

@@ -16,18 +16,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import beta as beta_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")
stats = try_import("scipy.stats")


class BetaTest(test.TestCase):
@@ -167,18 +182,22 @@ class BetaTest(test.TestCase):
    with session.Session():
      a = [1., 2, 3]
      b = [2., 4, 1.2]
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.mean().get_shape(), (3,))
      if not stats:
        return
      expected_mean = stats.beta.mean(a, b)
      self.assertAllClose(expected_mean, dist.mean().eval())

  def testBetaVariance(self):
    with session.Session():
      a = [1., 2, 3]
      b = [2., 4, 1.2]
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.variance().get_shape(), (3,))
      if not stats:
        return
      expected_variance = stats.beta.var(a, b)
      self.assertAllClose(expected_variance, dist.variance().eval())

  def testBetaMode(self):
@@ -228,9 +247,11 @@ class BetaTest(test.TestCase):
    with session.Session():
      a = [1., 2, 3]
      b = [2., 4, 1.2]
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.entropy().get_shape(), (3,))
      if not stats:
        return
      expected_entropy = stats.beta.entropy(a, b)
      self.assertAllClose(expected_entropy, dist.entropy().eval())

  def testBetaSample(self):
@@ -243,6 +264,8 @@ class BetaTest(test.TestCase):
      sample_values = samples.eval()
      self.assertEqual(sample_values.shape, (100000,))
      self.assertFalse(np.any(sample_values < 0.0))
      if not stats:
        return
      self.assertLess(
          stats.kstest(
              # Beta is a univariate distribution.
@@ -286,6 +309,8 @@ class BetaTest(test.TestCase):
      sample_values = samples.eval()
      self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
      self.assertFalse(np.any(sample_values < 0.0))
      if not stats:
        return
      self.assertAllClose(
          sample_values[:, 1, :].mean(axis=0),
          stats.beta.mean(a, b)[1, :],
@@ -301,6 +326,8 @@ class BetaTest(test.TestCase):
      actual = beta_lib.Beta(a, b).cdf(x).eval()
      self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
      self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
      if not stats:
        return
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)

  def testBetaLogCdf(self):
@@ -313,6 +340,8 @@ class BetaTest(test.TestCase):
      actual = math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)).eval()
      self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
      self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
      if not stats:
        return
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)

  def testBetaWithSoftplusConcentration(self):
@@ -342,6 +371,8 @@ class BetaTest(test.TestCase):
      d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
                                                     concentration0=b2_sp)
      if not special:
        return
      kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                     (a1 - a2) * special.digamma(a1) +
                     (b1 - b2) * special.digamma(b1) +

View File

@@ -20,7 +20,6 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
@@ -29,6 +28,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test

View File

@@ -17,14 +17,15 @@ from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import dirichlet_multinomial
from tensorflow.python.platform import test

ds = dirichlet_multinomial


class DirichletMultinomialTest(test.TestCase):

View File

@@ -16,14 +16,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


stats = try_import("scipy.stats")


class DirichletTest(test.TestCase):
@@ -132,9 +147,11 @@ class DirichletTest(test.TestCase):
  def testMean(self):
    with self.test_session():
      alpha = [1., 2, 3]
      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
      self.assertEqual(dirichlet.mean().get_shape(), [3])
      if not stats:
        return
      expected_mean = stats.dirichlet.mean(alpha)
      self.assertAllClose(dirichlet.mean().eval(), expected_mean)

  def testCovarianceFromSampling(self):
@@ -177,11 +194,13 @@ class DirichletTest(test.TestCase):
    with self.test_session():
      alpha = [1., 2, 3]
      denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
      self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
      if not stats:
        return
      expected_covariance = np.diag(stats.dirichlet.var(alpha))
      expected_covariance += [[0., -2, -3], [-2, 0, -6],
                              [-3, -6, 0]] / denominator
      self.assertAllClose(dirichlet.covariance().eval(), expected_covariance)

  def testMode(self):
@@ -213,9 +232,11 @@ class DirichletTest(test.TestCase):
  def testEntropy(self):
    with self.test_session():
      alpha = [1., 2, 3]
      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
      self.assertEqual(dirichlet.entropy().get_shape(), ())
      if not stats:
        return
      expected_entropy = stats.dirichlet.entropy(alpha)
      self.assertAllClose(dirichlet.entropy().eval(), expected_entropy)

  def testSample(self):
@@ -227,6 +248,8 @@ class DirichletTest(test.TestCase):
      sample_values = samples.eval()
      self.assertEqual(sample_values.shape, (100000, 2))
      self.assertTrue(np.all(sample_values > 0.0))
      if not stats:
        return
      self.assertLess(
          stats.kstest(
              # Beta is a univariate distribution.

View File

@ -18,13 +18,28 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import importlib
import numpy as np import numpy as np
from scipy import stats
from tensorflow.contrib.distributions.python.ops import exponential as exponential_lib
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import exponential as exponential_lib
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def try_import(name): # pylint: disable=invalid-name
module = None
try:
module = importlib.import_module(name)
except ImportError as e:
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module
stats = try_import("scipy.stats")
 class ExponentialTest(test.TestCase):
@@ -36,14 +51,17 @@ class ExponentialTest(test.TestCase):
       lam_v = 2.0
       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
       exponential = exponential_lib.Exponential(rate=lam)
-      expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
 
       log_pdf = exponential.log_prob(x)
       self.assertEqual(log_pdf.get_shape(), (6,))
-      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
 
       pdf = exponential.prob(x)
       self.assertEqual(pdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
+      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
       self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
 
   def testExponentialCDF(self):
@@ -54,34 +72,43 @@ class ExponentialTest(test.TestCase):
       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
       exponential = exponential_lib.Exponential(rate=lam)
-      expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
 
       cdf = exponential.cdf(x)
       self.assertEqual(cdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
       self.assertAllClose(cdf.eval(), expected_cdf)
 
   def testExponentialMean(self):
     with session.Session():
       lam_v = np.array([1.0, 4.0, 2.5])
-      expected_mean = stats.expon.mean(scale=1 / lam_v)
       exponential = exponential_lib.Exponential(rate=lam_v)
       self.assertEqual(exponential.mean().get_shape(), (3,))
+      if not stats:
+        return
+      expected_mean = stats.expon.mean(scale=1 / lam_v)
       self.assertAllClose(exponential.mean().eval(), expected_mean)
 
   def testExponentialVariance(self):
     with session.Session():
       lam_v = np.array([1.0, 4.0, 2.5])
-      expected_variance = stats.expon.var(scale=1 / lam_v)
       exponential = exponential_lib.Exponential(rate=lam_v)
       self.assertEqual(exponential.variance().get_shape(), (3,))
+      if not stats:
+        return
+      expected_variance = stats.expon.var(scale=1 / lam_v)
       self.assertAllClose(exponential.variance().eval(), expected_variance)
 
   def testExponentialEntropy(self):
     with session.Session():
       lam_v = np.array([1.0, 4.0, 2.5])
-      expected_entropy = stats.expon.entropy(scale=1 / lam_v)
       exponential = exponential_lib.Exponential(rate=lam_v)
       self.assertEqual(exponential.entropy().get_shape(), (3,))
+      if not stats:
+        return
+      expected_entropy = stats.expon.entropy(scale=1 / lam_v)
       self.assertAllClose(exponential.entropy().eval(), expected_entropy)
 
   def testExponentialSample(self):
@@ -95,6 +122,8 @@ class ExponentialTest(test.TestCase):
       sample_values = samples.eval()
       self.assertEqual(sample_values.shape, (100000, 2))
       self.assertFalse(np.any(sample_values < 0.0))
+      if not stats:
+        return
       for i in range(2):
         self.assertLess(
             stats.kstest(
@@ -116,6 +145,8 @@ class ExponentialTest(test.TestCase):
       sample_values = samples.eval()
       self.assertFalse(np.any(sample_values < 0.0))
+      if not stats:
+        return
       for i in range(2):
         self.assertLess(
             stats.kstest(
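The scipy reference values used throughout this file reduce to elementary closed forms under the rate parameterization (mean 1/lambda, variance 1/lambda^2, entropy 1 - log(lambda)). A quick numeric check, assuming scipy is installed:

import numpy as np
from scipy import stats  # assumed available for this check

lam = np.array([1.0, 4.0, 2.5])
assert np.allclose(stats.expon.mean(scale=1 / lam), 1 / lam)
assert np.allclose(stats.expon.var(scale=1 / lam), 1 / lam**2)
assert np.allclose(stats.expon.entropy(scale=1 / lam), 1 - np.log(lam))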


@@ -17,18 +17,32 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import importlib
+
 import numpy as np
-from scipy import special
-from scipy import stats
 
-from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import gamma as gamma_lib
 from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name):  # pylint: disable=invalid-name
+  module = None
+  try:
+    module = importlib.import_module(name)
+  except ImportError as e:
+    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+  return module
+
+
+special = try_import("scipy.special")
+stats = try_import("scipy.stats")
 class GammaTest(test.TestCase):
@@ -53,13 +67,14 @@ class GammaTest(test.TestCase):
       beta_v = 3.0
       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
       gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
       log_pdf = gamma.log_prob(x)
       self.assertEqual(log_pdf.get_shape(), (6,))
-      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
       pdf = gamma.prob(x)
       self.assertEqual(pdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
       self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
 
   def testGammaLogPDFMultidimensional(self):
@@ -71,15 +86,16 @@ class GammaTest(test.TestCase):
       beta_v = np.array([3.0, 4.0])
       x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
       gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
       log_pdf = gamma.log_prob(x)
       log_pdf_values = log_pdf.eval()
       self.assertEqual(log_pdf.get_shape(), (6, 2))
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
       pdf = gamma.prob(x)
       pdf_values = pdf.eval()
       self.assertEqual(pdf.get_shape(), (6, 2))
+      if not stats:
+        return
+      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+      self.assertAllClose(log_pdf_values, expected_log_pdf)
       self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testGammaLogPDFMultidimensionalBroadcasting(self):
@@ -91,15 +107,17 @@ class GammaTest(test.TestCase):
       beta_v = 3.0
       x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
       gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
       log_pdf = gamma.log_prob(x)
       log_pdf_values = log_pdf.eval()
       self.assertEqual(log_pdf.get_shape(), (6, 2))
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
       pdf = gamma.prob(x)
       pdf_values = pdf.eval()
       self.assertEqual(pdf.get_shape(), (6, 2))
+      if not stats:
+        return
+      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+      self.assertAllClose(log_pdf_values, expected_log_pdf)
       self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
   def testGammaCDF(self):
@@ -112,10 +130,11 @@ class GammaTest(test.TestCase):
       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
       gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
       cdf = gamma.cdf(x)
       self.assertEqual(cdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
       self.assertAllClose(cdf.eval(), expected_cdf)
 
   def testGammaMean(self):
@@ -123,8 +142,10 @@ class GammaTest(test.TestCase):
       alpha_v = np.array([1.0, 3.0, 2.5])
       beta_v = np.array([1.0, 4.0, 5.0])
       gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
       self.assertEqual(gamma.mean().get_shape(), (3,))
+      if not stats:
+        return
+      expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
       self.assertAllClose(gamma.mean().eval(), expected_means)
 
   def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
@@ -165,8 +186,10 @@ class GammaTest(test.TestCase):
       alpha_v = np.array([1.0, 3.0, 2.5])
       beta_v = np.array([1.0, 4.0, 5.0])
       gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
       self.assertEqual(gamma.variance().get_shape(), (3,))
+      if not stats:
+        return
+      expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
       self.assertAllClose(gamma.variance().eval(), expected_variances)
 
   def testGammaStd(self):
@@ -174,17 +197,21 @@ class GammaTest(test.TestCase):
       alpha_v = np.array([1.0, 3.0, 2.5])
       beta_v = np.array([1.0, 4.0, 5.0])
       gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
       self.assertEqual(gamma.stddev().get_shape(), (3,))
+      if not stats:
+        return
+      expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
       self.assertAllClose(gamma.stddev().eval(), expected_stddev)
 
   def testGammaEntropy(self):
     with self.test_session():
       alpha_v = np.array([1.0, 3.0, 2.5])
       beta_v = np.array([1.0, 4.0, 5.0])
-      expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
       gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
       self.assertEqual(gamma.entropy().get_shape(), (3,))
+      if not stats:
+        return
+      expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
       self.assertAllClose(gamma.entropy().eval(), expected_entropy)
   def testGammaSampleSmallAlpha(self):
@@ -199,6 +226,9 @@ class GammaTest(test.TestCase):
       sample_values = samples.eval()
       self.assertEqual(samples.get_shape(), (n,))
       self.assertEqual(sample_values.shape, (n,))
+      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+      if not stats:
+        return
       self.assertAllClose(
           sample_values.mean(),
           stats.gamma.mean(
@@ -208,7 +238,6 @@ class GammaTest(test.TestCase):
           sample_values.var(),
           stats.gamma.var(alpha_v, scale=1 / beta_v),
           atol=.15)
-      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
 
   def testGammaSample(self):
     with session.Session():
@@ -222,6 +251,9 @@ class GammaTest(test.TestCase):
       sample_values = samples.eval()
       self.assertEqual(samples.get_shape(), (n,))
       self.assertEqual(sample_values.shape, (n,))
+      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+      if not stats:
+        return
       self.assertAllClose(
           sample_values.mean(),
           stats.gamma.mean(
@@ -231,7 +263,6 @@ class GammaTest(test.TestCase):
           sample_values.var(),
           stats.gamma.var(alpha_v, scale=1 / beta_v),
           atol=.15)
-      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
 
   def testGammaSampleMultiDimensional(self):
     with session.Session():
@@ -246,6 +277,8 @@ class GammaTest(test.TestCase):
       zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
       alpha_bc = alpha_v + zeros
       beta_bc = beta_v + zeros
+      if not stats:
+        return
       self.assertAllClose(
           sample_values.mean(axis=0),
           stats.gamma.mean(
@@ -266,6 +299,8 @@ class GammaTest(test.TestCase):
   def _kstest(self, alpha, beta, samples):
     # Uses the Kolmogorov-Smirnov test for goodness of fit.
+    if not stats:
+      return True  # If we can't test, return that the test passes.
     ks, _ = stats.kstest(samples, stats.gamma(alpha, scale=1 / beta).cdf)
     # Return True when the test passes.
     return ks < 0.02
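_kstest above is a one-sample Kolmogorov-Smirnov goodness-of-fit check of the drawn samples against the analytic Gamma CDF. A standalone sketch of the same check, with scipy assumed available and numpy standing in for the TF sampler:

import numpy as np
from scipy import stats  # assumed available

alpha, beta = 3.0, 2.0  # illustrative concentration and rate
samples = np.random.RandomState(0).gamma(
    shape=alpha, scale=1.0 / beta, size=100000)
# Small KS statistic => empirical and analytic CDFs agree closely.
ks, _ = stats.kstest(samples, stats.gamma(alpha, scale=1.0 / beta).cdf)
assert ks < 0.02  # the pass threshold the tests use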
@@ -279,6 +314,12 @@ class GammaTest(test.TestCase):
       sample_vals, pdf_vals = sess.run([samples, pdfs])
       self.assertEqual(samples.get_shape(), (num, 2, 2))
       self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+      if not stats:
+        return
       self.assertAllClose(
           stats.gamma.mean(
               [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])),
@@ -289,10 +330,6 @@ class GammaTest(test.TestCase):
               scale=1 / np.array([[5., 5.], [6., 6.]])),
           sample_vals.var(axis=0),
           atol=.1)
-      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
 
   def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
     s_p = zip(sample_vals, pdf_vals)
@@ -350,6 +387,10 @@ class GammaTest(test.TestCase):
       # Execute graph.
       [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])
 
+      self.assertEqual(beta0.shape, kl_actual.get_shape())
+
+      if not special:
+        return
       kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                      + special.gammaln(alpha1)
                      - special.gammaln(alpha0)
@@ -357,7 +398,6 @@ class GammaTest(test.TestCase):
                      - alpha1 * np.log(beta1)
                      + alpha0 * (beta1 / beta0 - 1.))
 
-      self.assertEqual(beta0.shape, kl_actual.get_shape())
       self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
       self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
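kl_expected above is the standard closed-form KL divergence between two Gamma(concentration, rate) densities; the alpha1 * log(beta0) term falls in the gap between the two hunks shown. A standalone sketch with illustrative parameters, cross-checked by Monte Carlo (scipy assumed available):

import numpy as np
from scipy import special, stats

alpha0, beta0 = 2.0, 3.0   # p = Gamma(alpha0, rate=beta0)
alpha1, beta1 = 2.5, 1.5   # q = Gamma(alpha1, rate=beta1)

kl_closed_form = ((alpha0 - alpha1) * special.digamma(alpha0)
                  + special.gammaln(alpha1) - special.gammaln(alpha0)
                  + alpha1 * (np.log(beta0) - np.log(beta1))
                  + alpha0 * (beta1 / beta0 - 1.))

# Monte Carlo estimate of E_p[log p(X) - log q(X)] for comparison.
x = stats.gamma(alpha0, scale=1. / beta0).rvs(size=200000, random_state=0)
kl_mc = np.mean(stats.gamma(alpha0, scale=1. / beta0).logpdf(x)
                - stats.gamma(alpha1, scale=1. / beta1).logpdf(x))
assert abs(kl_closed_form - kl_mc) < 0.05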


@@ -17,15 +17,31 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import importlib
+
 import numpy as np
-from scipy import stats
 
-from tensorflow.contrib.distributions.python.ops import laplace as laplace_lib
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import laplace as laplace_lib
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name):  # pylint: disable=invalid-name
+  module = None
+  try:
+    module = importlib.import_module(name)
+  except ImportError as e:
+    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+  return module
+
+
+stats = try_import("scipy.stats")
 class LaplaceTest(test.TestCase):
@@ -49,9 +65,11 @@ class LaplaceTest(test.TestCase):
       scale_v = 3.0
       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
       laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
       log_pdf = laplace.log_prob(x)
       self.assertEqual(log_pdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
       self.assertAllClose(log_pdf.eval(), expected_log_pdf)
 
       pdf = laplace.prob(x)
@@ -67,15 +85,17 @@ class LaplaceTest(test.TestCase):
       scale_v = np.array([3.0, 4.0])
       x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
       laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
       log_pdf = laplace.log_prob(x)
       log_pdf_values = log_pdf.eval()
       self.assertEqual(log_pdf.get_shape(), (6, 2))
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
       pdf = laplace.prob(x)
       pdf_values = pdf.eval()
       self.assertEqual(pdf.get_shape(), (6, 2))
+      if not stats:
+        return
+      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+      self.assertAllClose(log_pdf_values, expected_log_pdf)
       self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testLaplaceLogPDFMultidimensionalBroadcasting(self):
@@ -87,15 +107,17 @@ class LaplaceTest(test.TestCase):
       scale_v = 3.0
       x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
       laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
       log_pdf = laplace.log_prob(x)
       log_pdf_values = log_pdf.eval()
       self.assertEqual(log_pdf.get_shape(), (6, 2))
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
       pdf = laplace.prob(x)
       pdf_values = pdf.eval()
       self.assertEqual(pdf.get_shape(), (6, 2))
+      if not stats:
+        return
+      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+      self.assertAllClose(log_pdf_values, expected_log_pdf)
       self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testLaplaceCDF(self):
@@ -108,10 +130,12 @@ class LaplaceTest(test.TestCase):
       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
       laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
       cdf = laplace.cdf(x)
       self.assertEqual(cdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
       self.assertAllClose(cdf.eval(), expected_cdf)
 
   def testLaplaceLogCDF(self):
@@ -124,10 +148,12 @@ class LaplaceTest(test.TestCase):
       x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
       laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
       cdf = laplace.log_cdf(x)
       self.assertEqual(cdf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
       self.assertAllClose(cdf.eval(), expected_cdf)
 
   def testLaplaceLogSurvivalFunction(self):
@@ -140,10 +166,12 @@ class LaplaceTest(test.TestCase):
       x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
       laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
       sf = laplace.log_survival_function(x)
       self.assertEqual(sf.get_shape(), (6,))
+      if not stats:
+        return
+      expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
       self.assertAllClose(sf.eval(), expected_sf)
   def testLaplaceMean(self):
@@ -151,8 +179,10 @@ class LaplaceTest(test.TestCase):
       loc_v = np.array([1.0, 3.0, 2.5])
       scale_v = np.array([1.0, 4.0, 5.0])
       laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      expected_means = stats.laplace.mean(loc_v, scale=scale_v)
       self.assertEqual(laplace.mean().get_shape(), (3,))
+      if not stats:
+        return
+      expected_means = stats.laplace.mean(loc_v, scale=scale_v)
       self.assertAllClose(laplace.mean().eval(), expected_means)
 
   def testLaplaceMode(self):
@@ -168,8 +198,10 @@ class LaplaceTest(test.TestCase):
       loc_v = np.array([1.0, 3.0, 2.5])
       scale_v = np.array([1.0, 4.0, 5.0])
       laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      expected_variances = stats.laplace.var(loc_v, scale=scale_v)
       self.assertEqual(laplace.variance().get_shape(), (3,))
+      if not stats:
+        return
+      expected_variances = stats.laplace.var(loc_v, scale=scale_v)
       self.assertAllClose(laplace.variance().eval(), expected_variances)
 
   def testLaplaceStd(self):
@@ -177,17 +209,21 @@ class LaplaceTest(test.TestCase):
       loc_v = np.array([1.0, 3.0, 2.5])
       scale_v = np.array([1.0, 4.0, 5.0])
       laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
       self.assertEqual(laplace.stddev().get_shape(), (3,))
+      if not stats:
+        return
+      expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
       self.assertAllClose(laplace.stddev().eval(), expected_stddev)
 
   def testLaplaceEntropy(self):
     with self.test_session():
       loc_v = np.array([1.0, 3.0, 2.5])
       scale_v = np.array([1.0, 4.0, 5.0])
-      expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
       laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
       self.assertEqual(laplace.entropy().get_shape(), (3,))
+      if not stats:
+        return
+      expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
       self.assertAllClose(laplace.entropy().eval(), expected_entropy)
 
   def testLaplaceSample(self):
@@ -202,6 +238,8 @@ class LaplaceTest(test.TestCase):
       sample_values = samples.eval()
       self.assertEqual(samples.get_shape(), (n,))
       self.assertEqual(sample_values.shape, (n,))
+      if not stats:
+        return
       self.assertAllClose(
           sample_values.mean(),
           stats.laplace.mean(
@@ -228,6 +266,8 @@ class LaplaceTest(test.TestCase):
       zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
       loc_bc = loc_v + zeros
       scale_bc = scale_v + zeros
+      if not stats:
+        return
       self.assertAllClose(
           sample_values.mean(axis=0),
           stats.laplace.mean(
@@ -250,6 +290,8 @@ class LaplaceTest(test.TestCase):
   def _kstest(self, loc, scale, samples):
     # Uses the Kolmogorov-Smirnov test for goodness of fit.
+    if not stats:
+      return True  # If scipy isn't available, return "True" for passing
     ks, _ = stats.kstest(samples, stats.laplace(loc, scale=scale).cdf)
     # Return True when the test passes.
     return ks < 0.02
 
@@ -263,6 +305,12 @@ class LaplaceTest(test.TestCase):
       sample_vals, pdf_vals = sess.run([samples, pdfs])
       self.assertEqual(samples.get_shape(), (num, 2, 2))
       self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+      if not stats:
+        return
       self.assertAllClose(
           stats.laplace.mean(
               [[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])),
@@ -275,10 +323,6 @@ class LaplaceTest(test.TestCase):
           sample_vals.var(axis=0),
           rtol=0.05,
           atol=0.)
-      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
 
   def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
     s_p = zip(sample_vals, pdf_vals)
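_assertIntegral, whose body is truncated in both fragments above, sanity-checks that the sampled pdf values integrate to roughly one over the sampled support. A sketch of that kind of check (not the exact implementation, which is outside this diff):

import numpy as np


def integral_over_samples(sample_vals, pdf_vals):
  # Sort the samples, then Riemann-sum pdf * dx over consecutive gaps;
  # for a well-covered support the result should be close to 1.
  order = np.argsort(sample_vals)
  xs = np.asarray(sample_vals)[order]
  ps = np.asarray(pdf_vals)[order]
  return float(np.sum(ps[:-1] * np.diff(xs)))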


@@ -17,15 +17,14 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib import distributions
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import multinomial
 from tensorflow.python.platform import test
 
-ds = distributions
 class MultinomialTest(test.TestCase):
@@ -35,7 +34,7 @@ class MultinomialTest(test.TestCase):
   def testSimpleShapes(self):
     with self.test_session():
       p = [.1, .3, .6]
-      dist = ds.Multinomial(total_count=1., probs=p)
+      dist = multinomial.Multinomial(total_count=1., probs=p)
       self.assertEqual(3, dist.event_shape_tensor().eval())
       self.assertAllEqual([], dist.batch_shape_tensor().eval())
       self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
@@ -45,7 +44,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
       n = [[3., 2], [4, 5], [6, 7]]
-      dist = ds.Multinomial(total_count=n, probs=p)
+      dist = multinomial.Multinomial(total_count=n, probs=p)
       self.assertEqual(2, dist.event_shape_tensor().eval())
       self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval())
       self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
@@ -55,14 +54,14 @@ class MultinomialTest(test.TestCase):
     p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
     n = [[3.], [4]]
     with self.test_session():
-      dist = ds.Multinomial(total_count=n, probs=p)
+      dist = multinomial.Multinomial(total_count=n, probs=p)
       self.assertEqual((2, 1), dist.total_count.get_shape())
       self.assertAllClose(n, dist.total_count.eval())
 
   def testP(self):
     p = [[0.1, 0.2, 0.7]]
     with self.test_session():
-      dist = ds.Multinomial(total_count=3., probs=p)
+      dist = multinomial.Multinomial(total_count=3., probs=p)
       self.assertEqual((1, 3), dist.probs.get_shape())
       self.assertEqual((1, 3), dist.logits.get_shape())
       self.assertAllClose(p, dist.probs.eval())
@@ -71,7 +70,7 @@ class MultinomialTest(test.TestCase):
     p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
     logits = np.log(p) - 50.
     with self.test_session():
-      multinom = ds.Multinomial(total_count=3., logits=logits)
+      multinom = multinomial.Multinomial(total_count=3., logits=logits)
       self.assertEqual((1, 3), multinom.probs.get_shape())
       self.assertEqual((1, 3), multinom.logits.get_shape())
       self.assertAllClose(p, multinom.probs.eval())
@@ -81,7 +80,7 @@ class MultinomialTest(test.TestCase):
     p = [[0.1, 0.2, 0.7]]
     n = [[5.]]
     with self.test_session():
-      dist = ds.Multinomial(total_count=n, probs=p, validate_args=True)
+      dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True)
       dist.prob([2., 3, 0]).eval()
       dist.prob([3., 0, 2]).eval()
       with self.assertRaisesOpError("must be non-negative"):
@@ -94,7 +93,8 @@ class MultinomialTest(test.TestCase):
     n = [[5.]]
     with self.test_session():
       # No errors with integer n.
-      multinom = ds.Multinomial(total_count=n, probs=p, validate_args=True)
+      multinom = multinomial.Multinomial(
+          total_count=n, probs=p, validate_args=True)
       multinom.prob([2., 1, 2]).eval()
       multinom.prob([3., 0, 2]).eval()
       # Counts don't sum to n.
@@ -106,7 +106,8 @@ class MultinomialTest(test.TestCase):
                                    "cannot contain fractional components."):
         multinom.prob(x).eval(feed_dict={x: [1.0, 2.5, 1.5]})
-      multinom = ds.Multinomial(total_count=n, probs=p, validate_args=False)
+      multinom = multinomial.Multinomial(
+          total_count=n, probs=p, validate_args=False)
       multinom.prob([1., 2., 2.]).eval()
       # Non-integer arguments work.
       multinom.prob([1.0, 2.5, 1.5]).eval()
@@ -116,7 +117,7 @@ class MultinomialTest(test.TestCase):
       # Both zero-batches.  No broadcast
       p = [0.5, 0.5]
       counts = [1., 0]
-      pmf = ds.Multinomial(total_count=1., probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
       self.assertAllClose(0.5, pmf.eval())
       self.assertEqual((), pmf.get_shape())
@@ -125,7 +126,7 @@ class MultinomialTest(test.TestCase):
       # Both zero-batches.  No broadcast
       p = [0.1, 0.9]
       counts = [3., 2]
-      dist = ds.Multinomial(total_count=5., probs=p)
+      dist = multinomial.Multinomial(total_count=5., probs=p)
       pmf = dist.prob(counts)
       # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
       self.assertAllClose(81. / 10000, pmf.eval())
@@ -135,7 +136,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       p = [[0.1, 0.9]]
       counts = [[1., 0], [0, 1]]
-      pmf = ds.Multinomial(total_count=1., probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
       self.assertAllClose([0.1, 0.9], pmf.eval())
       self.assertEqual((2), pmf.get_shape())
@@ -143,7 +144,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       p = [0.1, 0.9]
       counts = [[1., 0], [0, 1]]
-      pmf = ds.Multinomial(total_count=1., probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
       self.assertAllClose([0.1, 0.9], pmf.eval())
       self.assertEqual((2), pmf.get_shape())
@@ -151,7 +152,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       p = [[0.1, 0.9], [0.7, 0.3]]
       counts = [[1., 0]]
-      pmf = ds.Multinomial(total_count=1., probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
       self.assertAllClose(pmf.eval(), [0.1, 0.7])
       self.assertEqual((2), pmf.get_shape())
@@ -159,7 +160,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       p = [[0.1, 0.9], [0.7, 0.3]]
       counts = [1., 0]
-      pmf = ds.Multinomial(total_count=1., probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
       self.assertAllClose(pmf.eval(), [0.1, 0.7])
       self.assertEqual(pmf.get_shape(), (2))
@@ -171,7 +172,7 @@ class MultinomialTest(test.TestCase):
       n = [[3., 3], [3, 3]]
       # [2]
       counts = [2., 1]
-      pmf = ds.Multinomial(total_count=n, probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
       pmf.eval()
       self.assertEqual(pmf.get_shape(), (2, 2))
@@ -180,7 +181,7 @@ class MultinomialTest(test.TestCase):
       p = [0.1, 0.9]
       counts = [3., 2]
       n = np.full([4, 3], 5., dtype=np.float32)
-      pmf = ds.Multinomial(total_count=n, probs=p).prob(counts)
+      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
       pmf.eval()
       self.assertEqual((4, 3), pmf.get_shape())
@@ -188,7 +189,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       n = 5.
       p = [0.1, 0.2, 0.7]
-      dist = ds.Multinomial(total_count=n, probs=p)
+      dist = multinomial.Multinomial(total_count=n, probs=p)
       expected_means = 5 * np.array(p, dtype=np.float32)
       self.assertEqual((3,), dist.mean().get_shape())
       self.assertAllClose(expected_means, dist.mean().eval())
@@ -197,7 +198,7 @@ class MultinomialTest(test.TestCase):
     with self.test_session():
       n = 5.
       p = [0.1, 0.2, 0.7]
-      dist = ds.Multinomial(total_count=n, probs=p)
+      dist = multinomial.Multinomial(total_count=n, probs=p)
       expected_covariances = [[9. / 20, -1 / 10, -7 / 20],
                               [-1 / 10, 4 / 5, -7 / 10],
                               [-7 / 20, -7 / 10, 21 / 20]]
@@ -210,7 +211,7 @@ class MultinomialTest(test.TestCase):
       n = [5.] * 2
       # Shape [4, 1, 2]
       p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
-      dist = ds.Multinomial(total_count=n, probs=p)
+      dist = multinomial.Multinomial(total_count=n, probs=p)
       # Shape [2, 2]
       inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
       # Shape [4, 2, 2, 2]
@@ -228,8 +229,8 @@ class MultinomialTest(test.TestCase):
     ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
 
     with self.test_session():
-      dist = ds.Multinomial(ns, p)
-      dist2 = ds.Multinomial(ns2, p2)
+      dist = multinomial.Multinomial(ns, p)
+      dist2 = multinomial.Multinomial(ns2, p2)
 
       covariance = dist.covariance()
       covariance2 = dist2.covariance()
@@ -246,7 +247,8 @@ class MultinomialTest(test.TestCase):
     # doesn't support different total counts.
     n = np.float32(5)
     with self.test_session() as sess:
-      dist = ds.Multinomial(n, theta)  # batch_shape=[2], event_shape=[3]
+      # batch_shape=[2], event_shape=[3]
+      dist = multinomial.Multinomial(n, theta)
       x = dist.sample(int(250e3), seed=1)
       sample_mean = math_ops.reduce_mean(x, 0)
       x_centered = x - sample_mean[array_ops.newaxis, ...]
@@ -281,7 +283,7 @@ class MultinomialTest(test.TestCase):
 
   def testSampleUnbiasedNonScalarBatch(self):
     with self.test_session() as sess:
-      dist = ds.Multinomial(
+      dist = multinomial.Multinomial(
           total_count=5.,
           logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
       n = int(3e3)
@@ -310,7 +312,7 @@ class MultinomialTest(test.TestCase):
 
   def testSampleUnbiasedScalarBatch(self):
     with self.test_session() as sess:
-      dist = ds.Multinomial(
+      dist = multinomial.Multinomial(
          total_count=5.,
          logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32)))
       n = int(5e3)
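The hand-written expected_means and expected_covariances earlier in this file follow directly from the multinomial moment formulas E[X] = n p and Cov[X] = n (diag(p) - p p^T). A numpy sketch that reproduces the test's numbers:

import numpy as np

n = 5.
p = np.array([0.1, 0.2, 0.7])

expected_means = n * p  # [0.5, 1.0, 3.5]
expected_covariances = n * (np.diag(p) - np.outer(p, p))
# Diagonal entries are n * p_i * (1 - p_i), e.g. entry [0, 0] is 9/20;
# off-diagonals are -n * p_i * p_j, e.g. entry [0, 1] is -1/10.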


@@ -18,19 +18,30 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import importlib
 import math
 
 import numpy as np
-from scipy import stats
 
-from tensorflow.contrib import distributions
-from tensorflow.contrib.distributions.python.ops import student_t
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import student_t
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
 
-ds = distributions
+
+def try_import(name):  # pylint: disable=invalid-name
+  module = None
+  try:
+    module = importlib.import_module(name)
+  except ImportError as e:
+    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+  return module
+
+
+stats = try_import("scipy.stats")
 class StudentTTest(test.TestCase):
@@ -45,7 +56,7 @@ class StudentTTest(test.TestCase):
       mu_v = 7.
       sigma_v = 8.
       t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
-      student = ds.StudentT(df, loc=mu, scale=-sigma)
+      student = student_t.StudentT(df, loc=mu, scale=-sigma)
 
       log_pdf = student.log_prob(t)
       self.assertEquals(log_pdf.get_shape(), (6,))
@@ -54,6 +65,9 @@ class StudentTTest(test.TestCase):
       self.assertEquals(pdf.get_shape(), (6,))
       pdf_values = pdf.eval()
 
+      if not stats:
+        return
+
       expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
       expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
       self.assertAllClose(expected_log_pdf, log_pdf_values)
@@ -72,13 +86,16 @@ class StudentTTest(test.TestCase):
       mu_v = np.array([3., -3.])
       sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
       t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
-      student = ds.StudentT(df, loc=mu, scale=sigma)
+      student = student_t.StudentT(df, loc=mu, scale=sigma)
       log_pdf = student.log_prob(t)
       log_pdf_values = log_pdf.eval()
       self.assertEqual(log_pdf.get_shape(), (6, 2))
       pdf = student.prob(t)
       pdf_values = pdf.eval()
       self.assertEqual(pdf.get_shape(), (6, 2))
+
+      if not stats:
+        return
       expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
       expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
       self.assertAllClose(expected_log_pdf, log_pdf_values)
@@ -105,6 +122,8 @@ class StudentTTest(test.TestCase):
       self.assertEquals(cdf.get_shape(), (6,))
       cdf_values = cdf.eval()
 
+      if not stats:
+        return
       expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
       expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
       self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
@@ -119,7 +138,7 @@ class StudentTTest(test.TestCase):
     mu_v = np.array([[1., -1, 0]])  # 1x3
     sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
     with self.test_session():
-      student = ds.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
+      student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
       ent = student.entropy()
       ent_values = ent.eval()
@@ -128,6 +147,8 @@ class StudentTTest(test.TestCase):
       sigma_bc = np.abs(sigma_v) * ones
       mu_bc = ones.T * mu_v
       df_bc = ones.T * df_v
+      if not stats:
+        return
       expected_entropy = stats.t.entropy(
           np.reshape(df_bc, [-1]),
           loc=np.reshape(mu_bc, [-1]),
@@ -144,7 +165,7 @@ class StudentTTest(test.TestCase):
       mu_v = 3.
       sigma_v = np.sqrt(10.)
       n = constant_op.constant(200000)
-      student = ds.StudentT(df=df, loc=mu, scale=sigma)
+      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
       samples = student.sample(n, seed=123456)
       sample_values = samples.eval()
       n_val = 200000
@@ -166,11 +187,13 @@ class StudentTTest(test.TestCase):
       n = constant_op.constant(100)
 
       random_seed.set_random_seed(654321)
-      student = ds.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
+      student = student_t.StudentT(
+          df=df, loc=mu, scale=sigma, name="student_t1")
       samples1 = student.sample(n, seed=123456).eval()
 
       random_seed.set_random_seed(654321)
-      student2 = ds.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
+      student2 = student_t.StudentT(
+          df=df, loc=mu, scale=sigma, name="student_t2")
       samples2 = student2.sample(n, seed=123456).eval()
 
       self.assertAllClose(samples1, samples2)
@@ -180,7 +203,7 @@ class StudentTTest(test.TestCase):
       df_v = [1e-1, 1e-5, 1e-10, 1e-20]
       df = constant_op.constant(df_v)
       n = constant_op.constant(200000)
-      student = ds.StudentT(df=df, loc=1., scale=1.)
+      student = student_t.StudentT(df=df, loc=1., scale=1.)
       samples = student.sample(n, seed=123456)
       sample_values = samples.eval()
       n_val = 200000
@@ -198,7 +221,7 @@ class StudentTTest(test.TestCase):
       mu_v = [3., -3.]
       sigma_v = [np.sqrt(10.), np.sqrt(15.)]
       n = constant_op.constant(200000)
-      student = ds.StudentT(df=df, loc=mu, scale=sigma)
+      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
       samples = student.sample(n, seed=123456)
       sample_values = samples.eval()
       self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
@@ -222,6 +245,8 @@ class StudentTTest(test.TestCase):
   def _checkKLApprox(self, df, mu, sigma, samples):
     n = samples.size
     np.random.seed(137)
+    if not stats:
+      return
     sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
     covg = 0.99
     r = stats.t.interval(covg, df, loc=mu, scale=sigma)
@@ -247,9 +272,9 @@ class StudentTTest(test.TestCase):
       self.assertEqual(student.prob(2.).get_shape(), (3,))
       self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,))
 
-    _check(ds.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
-    _check(ds.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
-    _check(ds.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
+    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
+    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
+    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
 
   def testBroadcastingPdfArgs(self):
@@ -266,9 +291,9 @@ class StudentTTest(test.TestCase):
       xs = xs.T
       _assert_shape(student, xs, (3, 3))
 
-    _check(ds.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
-    _check(ds.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
-    _check(ds.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
+    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
+    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
+    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
 
     def _check2d(student):
       _assert_shape(student, 2., (1, 3))
@@ -279,9 +304,9 @@ class StudentTTest(test.TestCase):
       xs = xs.T
       _assert_shape(student, xs, (3, 3))
 
-    _check2d(ds.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
-    _check2d(ds.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
-    _check2d(ds.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))
+    _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
+    _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
+    _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))
 
    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
@@ -292,22 +317,23 @@ class StudentTTest(test.TestCase):
       xs = xs.T  # (3,1)
       _assert_shape(student, xs, (3, 1))
 
-    _check2d_rows(ds.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
-    _check2d_rows(ds.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
-    _check2d_rows(ds.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
+    _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
+    _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
+    _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
 
   def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
     with self.test_session():
       mu = [1., 3.3, 4.4]
-      student = ds.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
+      student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
       mean = student.mean().eval()
       self.assertAllClose([1., 3.3, 4.4], mean)
 
   def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
     with self.test_session():
       mu = [1., 3.3, 4.4]
-      student = ds.StudentT(df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
-                            allow_nan_stats=False)
+      student = student_t.StudentT(
+          df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
+          allow_nan_stats=False)
       with self.assertRaisesOpError("x < y"):
         student.mean().eval()
@@ -315,8 +341,9 @@ class StudentTTest(test.TestCase):
     with self.test_session():
       mu = [-2, 0., 1., 3.3, 4.4]
       sigma = [5., 4., 3., 2., 1.]
-      student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
-                            allow_nan_stats=True)
+      student = student_t.StudentT(
+          df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
+          allow_nan_stats=True)
       mean = student.mean().eval()
       self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
@ -327,7 +354,8 @@ class StudentTTest(test.TestCase):
df = [0.5, 1.5, 3., 5., 7.] df = [0.5, 1.5, 3., 5., 7.]
mu = [-2, 0., 1., 3.3, 4.4] mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.] sigma = [5., 4., 3., 2., 1.]
student = ds.StudentT(df=df, loc=mu, scale=sigma, allow_nan_stats=True) student = student_t.StudentT(
df=df, loc=mu, scale=sigma, allow_nan_stats=True)
var = student.variance().eval() var = student.variance().eval()
## scipy uses inf for variance when the mean is undefined. When mean is ## scipy uses inf for variance when the mean is undefined. When mean is
# undefined we say variance is undefined as well. So test the first # undefined we say variance is undefined as well. So test the first
@ -336,6 +364,8 @@ class StudentTTest(test.TestCase):
self.assertTrue(np.isnan(var[0])) self.assertTrue(np.isnan(var[0]))
var[0] = np.inf var[0] = np.inf
if not stats:
return
expected_var = [ expected_var = [
stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
] ]
@ -348,9 +378,11 @@ class StudentTTest(test.TestCase):
df = [1.5, 3., 5., 7.] df = [1.5, 3., 5., 7.]
mu = [0., 1., 3.3, 4.4] mu = [0., 1., 3.3, 4.4]
sigma = [4., 3., 2., 1.] sigma = [4., 3., 2., 1.]
student = ds.StudentT(df=df, loc=mu, scale=sigma) student = student_t.StudentT(df=df, loc=mu, scale=sigma)
var = student.variance().eval() var = student.variance().eval()
if not stats:
return
expected_var = [ expected_var = [
stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
] ]
@@ -359,13 +391,15 @@ class StudentTTest(test.TestCase):

  def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    with self.test_session():
      # df <= 1 ==> variance not defined
      student = student_t.StudentT(
          df=1., loc=0., scale=1., allow_nan_stats=False)
      with self.assertRaisesOpError("x < y"):
        student.variance().eval()

    with self.test_session():
      # df <= 1 ==> variance not defined
      student = student_t.StudentT(
          df=0.5, loc=0., scale=1., allow_nan_stats=False)
      with self.assertRaisesOpError("x < y"):
        student.variance().eval()
@@ -375,11 +409,13 @@ class StudentTTest(test.TestCase):
      df = [3.5, 5., 3., 5., 7.]
      mu = [-2.2]
      sigma = [5., 4., 3., 2., 1.]
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      # Test broadcast of mu across shape of df/sigma
      stddev = student.stddev().eval()
      mu *= len(df)

      if not stats:
        return
      expected_stddev = [
          stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
      ]
@@ -390,14 +426,14 @@ class StudentTTest(test.TestCase):
      df = [0.5, 1., 3]
      mu = [-1, 0., 1]
      sigma = [5., 4., 3.]
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      # Test broadcast of mu across shape of df/sigma
      mode = student.mode().eval()
      self.assertAllClose([-1., 0, 1], mode)

  def testPdfOfSample(self):
    with self.test_session() as sess:
      student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
      num = 20000
      samples = student.sample(num, seed=123456)
      pdfs = student.prob(samples)
@@ -410,13 +446,15 @@ class StudentTTest(test.TestCase):
      self.assertEqual(mean.get_shape(), ())
      self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
      self.assertNear(np.pi, mean_val, err=1e-6)

      # Verify integral over sample*pdf ~= 1.
      self._assertIntegral(sample_vals, pdf_vals, err=2e-3)

      if not stats:
        return
      self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)

  def testPdfOfSampleMultiDims(self):
    with self.test_session() as sess:
      student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
      self.assertAllEqual([], student.event_shape)
      self.assertAllEqual([], student.event_shape_tensor().eval())
      self.assertAllEqual([2, 2], student.batch_shape)
@@ -429,6 +467,12 @@ class StudentTTest(test.TestCase):
      self.assertEqual(pdfs.get_shape(), (num, 2, 2))
      self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
      self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)

      if not stats:
        return
      self.assertNear(
          stats.t.var(7., loc=0., scale=3.),  # loc d.n. effect var
          np.var(sample_vals[:, :, 0]),
@@ -437,10 +481,6 @@ class StudentTTest(test.TestCase):
          stats.t.var(11., loc=0., scale=3.),  # loc d.n. effect var
          np.var(sample_vals[:, :, 1]),
          err=.4)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
    s_p = zip(sample_vals, pdf_vals)
@@ -454,8 +494,8 @@ class StudentTTest(test.TestCase):

  def testNegativeDofFails(self):
    with self.test_session():
      student = student_t.StudentT(df=[2, -5.], loc=0., scale=1.,
                                   validate_args=True, name="S")
      with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
        student.mean().eval()
@@ -464,7 +504,8 @@ class StudentTTest(test.TestCase):
      df = constant_op.constant([-3.2, -4.6])
      mu = constant_op.constant([-4.2, 3.4])
      sigma = constant_op.constant([-6.4, -8.8])
      student = student_t.StudentTWithAbsDfSoftplusScale(
          df=df, loc=mu, scale=sigma)
      self.assertAllClose(
          math_ops.floor(math_ops.abs(df)).eval(), student.df.eval())
      self.assertAllClose(mu.eval(), student.loc.eval())
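Taken together, these hunks replace the contrib alias `ds` with the core `student_t` module. A minimal sketch of the new-style construction (the import path is an assumption inferred from the uniform and gamma import moves later in this diff; the parameter values are illustrative):

```python
# Assumed core import path, mirroring the other distribution moves in this diff.
from tensorflow.python.ops.distributions import student_t

dist = student_t.StudentT(df=3., loc=0., scale=1.)
samples = dist.sample(10, seed=42)  # a length-10 batch of draws
```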


@@ -18,15 +18,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import uniform as uniform_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports `name` if available; returns None (with a warning) otherwise."""
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


stats = try_import("scipy.stats")
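With scipy now an optional dependency, every test that compares against a scipy reference first bails out when `stats` is None; that guard is exactly what the hunks below add. A hedged sketch of the pattern (the test body here is illustrative, not part of the diff):

```python
def testVarianceMatchesScipy(self):
  with self.test_session():
    uniform = uniform_lib.Uniform(low=10.0, high=100.0)
    if not stats:  # scipy unavailable; skip only the reference comparison.
      return
    self.assertAllClose(uniform.variance().eval(),
                        stats.uniform(loc=10.0, scale=90.0).var())
```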

class UniformTest(test.TestCase):
@@ -126,7 +141,7 @@ class UniformTest(test.TestCase):
      b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
      uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)

      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               "x < y"):
        uniform.low.eval()
@@ -187,6 +202,8 @@ class UniformTest(test.TestCase):
      a = 10.0
      b = 100.0
      uniform = uniform_lib.Uniform(low=a, high=b)
      if not stats:
        return
      s_uniform = stats.uniform(loc=a, scale=b - a)
      self.assertAllClose(uniform.mean().eval(), s_uniform.mean())
@@ -195,6 +212,8 @@ class UniformTest(test.TestCase):
      a = 10.0
      b = 100.0
      uniform = uniform_lib.Uniform(low=a, high=b)
      if not stats:
        return
      s_uniform = stats.uniform(loc=a, scale=b - a)
      self.assertAllClose(uniform.variance().eval(), s_uniform.var())
@@ -203,6 +222,8 @@ class UniformTest(test.TestCase):
      a = 10.0
      b = 100.0
      uniform = uniform_lib.Uniform(low=a, high=b)
      if not stats:
        return
      s_uniform = stats.uniform(loc=a, scale=b - a)
      self.assertAllClose(uniform.stddev().eval(), s_uniform.std())


@@ -24,6 +24,7 @@ py_library(
        "//tensorflow/python:nn",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:special_math_ops",
        "//tensorflow/python:util",
    ],
)


@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util


class ConditionalDistribution(distribution.Distribution):


@@ -20,13 +20,13 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma

__all__ = [


@@ -70,14 +70,14 @@ class Normal(distribution.Distribution):

  ```python
  # Define a single scalar Normal distribution.
  dist = tf.distributions.Normal(loc=0., scale=3.)

  # Evaluate the cdf at 1, returning a scalar.
  dist.cdf(1.)

  # Define a batch of two scalar valued Normals.
  # The first has mean 1 and standard deviation 11, the second 2 and 22.
  dist = tf.distributions.Normal(loc=[1, 2.], scale=[11, 22.])

  # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
  # returning a length two tensor.
@@ -92,7 +92,7 @@ class Normal(distribution.Distribution):

  ```python
  # Define a batch of two scalar valued Normals.
  # Both have mean 1, but different standard deviations.
  dist = tf.distributions.Normal(loc=1., scale=[11, 22.])

  # Evaluate the pdf of both distributions on the same point, 3.0,
  # returning a length 2 tensor.


@@ -42,8 +42,10 @@ __all__ = [


class StudentT(distribution.Distribution):
  """Student's t-distribution.

  This distribution has parameters: degrees of freedom `df`, location `loc`,
  and `scale`.
  #### Mathematical details

@@ -82,7 +84,7 @@ class StudentT(distribution.Distribution):

  ```python
  # Define a single scalar Student t distribution.
  single_dist = tf.distributions.StudentT(df=3)

  # Evaluate the pdf at 1, returning a scalar Tensor.
  single_dist.prob(1.)

@@ -90,7 +92,7 @@ class StudentT(distribution.Distribution):
  # Define a batch of two scalar valued Student t's.
  # The first has degrees of freedom 2, mean 1, and scale 11.
  # The second 3, 2 and 22.
  multi_dist = tf.distributions.StudentT(df=[2, 3],
                                         loc=[1, 2.],
                                         scale=[11, 22.])

@@ -107,7 +109,7 @@ class StudentT(distribution.Distribution):
  # Define a batch of two Student's t distributions.
  # Both have df 2 and mean 1, but different scales.
  dist = tf.distributions.StudentT(df=2, loc=1, scale=[11, 22.])

  # Evaluate the pdf of both distributions on the same point, 3.0,
  # returning a length 2 tensor.


@@ -57,7 +57,7 @@ def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):

def get_regularization_losses(scope=None):
  """Gets the list of regularization losses.

  Args:
    scope: An optional scope for filtering the losses to return.
@@ -88,7 +88,11 @@ def get_regularization_loss(scope=None, name="total_regularization_loss"):

def get_total_loss(add_regularization_losses=True, name="total_loss"):
  """Returns a tensor whose value represents the total loss.

  In particular, this adds any losses you have added with `tf.add_loss()`
  to any regularization losses that have been added by regularization
  parameters on layer constructors, e.g. `tf.layers`. Be sure to use this
  if you are constructing a loss op manually; otherwise the regularization
  arguments on `tf.layers` methods will have no effect.

  Args:
    add_regularization_losses: A boolean indicating whether or not to use the
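The pitfall this docstring warns about is easy to hit. A minimal sketch, assuming the TF 1.x `tf.layers`/`tf.losses` API (the layer sizes and regularizer scale are illustrative):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
# kernel_regularizer only appends a term to the REGULARIZATION_LOSSES
# collection; nothing folds it into your loss automatically.
y = tf.layers.dense(
    x, 1, kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))
data_loss = tf.reduce_mean(tf.square(y))
tf.losses.add_loss(data_loss)

# Optimizing data_loss alone would silently drop the L2 penalty;
# get_total_loss() combines it with the collected regularization losses.
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
```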


@@ -97,9 +97,10 @@ def assert_broadcastable(weights, values):
      return control_flow_ops.no_op(name="static_scalar_check_success")
    if weights_rank_static != values_rank_static:
      raise ValueError(
          "%s values.rank=%s. weights.rank=%s."
          " values.shape=%s. weights.shape=%s." % (
              _ASSERT_BROADCASTABLE_ERROR_PREFIX, values_rank_static,
              weights_rank_static, values.shape, weights.shape))
    weights_shape_static = tensor_util.constant_value(weights_shape)
    values_shape_static = tensor_util.constant_value(values_shape)
    if weights_shape_static is not None and values_shape_static is not None:
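With the static shapes included, a rank mismatch becomes self-diagnosing. A sketch of what the check now reports (the shapes are illustrative; the exact message prefix comes from `_ASSERT_BROADCASTABLE_ERROR_PREFIX`):

```python
import tensorflow as tf
from tensorflow.python.ops import weights_broadcast_ops

values = tf.zeros([2, 3])      # rank 2
weights = tf.zeros([2, 3, 4])  # rank 3, so not broadcastable to values
# Raises ValueError naming both ranks and, after this change, both static
# shapes, e.g. "... values.rank=2. weights.rank=3. values.shape=(2, 3).
# weights.shape=(2, 3, 4)."
weights_broadcast_ops.assert_broadcastable(weights, values)
```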


@@ -27,6 +27,23 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_mod


def _maybe_name(obj):
  """Returns object name if it has one, or a message otherwise.

  This is useful for names that appear in error messages.

  Args:
    obj: Object to get the name of.

  Returns:
    The name, "None", or a "no name" message.
  """
  if obj is None:
    return "None"
  elif hasattr(obj, "name"):
    return obj.name
  else:
    return "<no name for %s>" % type(obj)

class SessionManager(object):
  """Training helper that restores from checkpoint and creates session.
@@ -267,8 +284,8 @@ class SessionManager(object):
    if not local_init_success:
      raise RuntimeError(
          "Init operations did not make model ready for local_init. "
          "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                   init_fn,
                                                   msg))

    is_ready, msg = self._model_ready(sess)
@@ -276,8 +293,7 @@ class SessionManager(object):
      raise RuntimeError(
          "Init operations did not make model ready. "
          "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
          (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    return sess

  def recover_session(self,


@@ -497,6 +497,23 @@ class SessionManagerTest(test.TestCase):
                           "Init operations did not make model ready"):
        sm2.prepare_session("", init_op=v.initializer)

  def testPrepareSessionDidNotInitLocalVariableList(self):
    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      w = variables.Variable(
          v,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="w")
      with self.test_session():
        self.assertEqual(False, variables.is_variable_initialized(v).eval())
        self.assertEqual(False, variables.is_variable_initialized(w).eval())
      sm2 = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      with self.assertRaisesRegexp(RuntimeError,
                                   "Init operations did not make model ready"):
        sm2.prepare_session("", init_op=[v.initializer])

  def testPrepareSessionWithReadyNotReadyForLocal(self):
    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")


@@ -1185,7 +1185,7 @@ def tf_version_info_genrule():
      ],
      outs=["util/version_info.cc"],
      cmd=
      "$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\"",
      local=1,
      tools=[clean_dep("//tensorflow/tools/git:gen_git_source.py")],)
