Merge pull request #14630 from jhseu/branch_175983704

Branch 175983704
Jonathan Hseu, 2017-11-16 12:33:27 -08:00, committed by GitHub
commit 9089ab5982
343 changed files with 12984 additions and 5642 deletions

View File

@ -375,6 +375,7 @@ config_setting(
package_group(
name = "internal",
packages = [
"//learning/meta_rank/...",
"//tensorflow/...",
"//tensorflow_fold/llgtm/...",
],

View File

@ -39,21 +39,23 @@ static void AllocateFlags() {
flags->tf_xla_min_cluster_size = 2;
flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
flags->tf_xla_clustering_debug = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"Experimental."),
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
"Minimum number of operators in an XLA compilation. Ignored for "
"operators placed on an XLA device or operators explicitly marked "
"for compilation."),
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
});
flags->tf_xla_cpu_global_jit = false;
flag_list = new std::vector<Flag>(
{Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"Experimental."),
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
"Minimum number of operators in an XLA compilation. Ignored for "
"operators placed on an XLA device or operators explicitly marked "
"for compilation."),
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
"Enables global JIT compilation for CPU via SessionOptions.")});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
}
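
For reference, these flags are parsed out of the environment by ParseFlagsFromEnv, so the new CPU global JIT can be toggled without code changes. A minimal sketch of reading the flag (assuming the usual TF_XLA_FLAGS convention, e.g. TF_XLA_FLAGS="--tf_xla_cpu_global_jit=true", and the accessor declared next to this struct):

#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical call site, for illustration only.
void LogCpuGlobalJitSetting() {
  tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
      tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
  LOG(INFO) << "tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
}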

View File

@ -46,6 +46,8 @@ typedef struct {
int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA
// compilation.
bool tf_xla_clustering_debug; // Dump graphs during XLA compilation.
bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU
// via SessionOptions.
} MarkForCompilationPassFlags;
// Return a pointer to the MarkForCompilationPassFlags struct;

View File

@ -290,9 +290,11 @@ Status MarkForCompilationPass::Run(
global_jit_level =
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
}
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
const FunctionLibraryDefinition* fld = options.flib_def;
auto is_compilable = [global_jit_level, fld](const Node* node,
const DeviceType& device_type) {
auto is_compilable = [global_jit_level, cpu_global_jit, fld](
const Node* node, const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
@ -315,7 +317,11 @@ Status MarkForCompilationPass::Run(
if (status.ok()) return compile;
// Otherwise use the value of global_jit_level.
return registration->enable_jit_by_default && global_jit_level > 0;
// Ignore enable_jit_by_default if global JIT compilation for CPU is
// explicitly requested via the tf_xla_cpu_global_jit flag.
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
return (ignore_registration || registration->enable_jit_by_default) &&
global_jit_level > 0;
};
return RunImpl(options, is_compilable);
}
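
In table form, a sketch of the fallback decision above for a node placed on DEVICE_CPU (the explicit kXlaCompileAttr check earlier in the lambda still takes precedence), assuming global_jit_level > 0:

// tf_xla_cpu_global_jit | enable_jit_by_default | clustered?
// false                 | false                 | no
// false                 | true                  | yes
// true                  | (ignored)             | yes
//
// On non-CPU devices tf_xla_cpu_global_jit has no effect;
// enable_jit_by_default alone decides the fallback.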
@ -556,6 +562,7 @@ Status MarkForCompilationPass::RunImpl(
if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
registration->requires_compilation) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = strings::StrCat("cluster_", cluster_sequence_num++);
}

View File

@ -179,7 +179,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",

View File

@ -731,11 +731,12 @@ string DebugString(const Graph& graph,
FunctionalizeCond::ClusterHandle::Vector* clusters) {
string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n";
std::map<FunctionalizeCond::ClusterHandle, string> subgraphs;
auto name = [](const Node* n) {
return strings::StrCat(n->type_string(), "_", n->id());
};
for (Node* n : graph.nodes()) {
if (n->IsOp()) {
strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(),
" [label=\"", n->name(), "\"];\n");
}
strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"",
name(n), "\"];\n");
}
for (auto kv : subgraphs) {
strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n",
@ -743,16 +744,11 @@ string DebugString(const Graph& graph,
kv.first.ToString(), "\";\n", kv.second, "}\n");
}
for (Node* n : graph.nodes()) {
if (!n->IsOp()) {
continue;
}
for (Node* in : n->in_nodes()) {
if (in->IsOp()) {
strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n");
}
strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n");
}
}
return strings::StrCat(ret, "}");
return strings::StrCat(ret, "} // end");
}
string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) {
@ -761,16 +757,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) {
return cluster.representative.ToString();
};
for (auto kv : clustered_graph) {
strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second),
" (", kv.second.switch_nodes.size(), ", ",
kv.second.merge_nodes.size(), ")\"];\n");
if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) {
strings::StrAppend(
&ret, kv.first.ToString(), " [label=\"", name(kv.second),
kv.second.switch_nodes.empty()
? ""
: strings::StrCat(" switches=", kv.second.switch_nodes.size()),
kv.second.merge_nodes.empty()
? ""
: strings::StrCat(" merges=", kv.second.merge_nodes.size()),
"\"];\n");
}
}
for (auto kv : clustered_graph) {
for (auto in : kv.second.in_nodes) {
strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n");
}
}
return strings::StrCat(ret, "}");
return strings::StrCat(ret, "} // end");
}
bool IsDeadSwitch(const Node* node) {
@ -790,9 +794,6 @@ bool IsDeadSwitch(const Node* node) {
void FunctionalizeCond::CreateClusters() {
for (Node* node : graph_->nodes()) {
if (!node->IsOp()) {
continue;
}
if (IsSwitch(node)) {
switch_nodes_.insert(node);
} else if (IsMerge(node)) {
@ -825,6 +826,10 @@ void FunctionalizeCond::CreateClusters() {
clusters_.at(node).Merge(&clusters_.at(in));
}
}
// Group all source clusters together.
if (node->IsSource() || node->in_edges().empty()) {
clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId)));
}
}
}
@ -876,7 +881,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
for (const Node* in : node->in_nodes()) {
ClusterHandle other_repr = Representative(in);
// Skip source, sink and internal edges.
if (!in->IsOp() || other_repr == repr) {
if (other_repr == repr) {
continue;
}
Cluster& cluster_node_in = clustered_graph_[other_repr];
@ -887,7 +892,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
for (const Node* out : node->out_nodes()) {
ClusterHandle other_repr = Representative(out);
// Skip source, sink and internal edges.
if (!out->IsOp() || other_repr == repr) {
if (other_repr == repr) {
continue;
}
Cluster& cluster_node_out = clustered_graph_[other_repr];
@ -897,6 +902,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
}
return cluster_node;
};
update_cluster_for_node(graph_->source_node());
for (Node* node : switch_nodes_) {
update_cluster_for_node(node).switch_nodes.insert(node);
}
@ -955,7 +961,7 @@ gtl::optional<FunctionalizeCond::Cluster*> FunctionalizeCond::GetSwitchCluster(
for (Cluster* in : merge_cluster.in_nodes) {
Cluster* cluster = in;
if (in->switch_nodes.empty()) {
if (in->in_nodes.size() != 1) {
if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) {
return gtl::nullopt;
}
// There is only a single `in` cluster.
@ -1292,11 +1298,8 @@ std::vector<std::pair<int, FunctionalizeCond::Cluster*>>
FunctionalizeCond::SortedMergeNodes() {
VLOG(2) << "ProcessClusteredGraph";
std::stack<std::pair<int, Cluster*>> stack;
for (auto& c : clustered_graph_) {
if (c.second.in_nodes.empty()) {
stack.push({0, &c.second});
}
}
// Initialize with the source node.
stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]});
// Perform a depth-first traversal of the clustered graph computing the
// switch-merge depth.

View File

@ -40,6 +40,11 @@ class ConstOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape(proto_.tensor_shape());
if (proto_.dtype() == DT_STRING) {
LOG(WARNING) << "Not computing Const of type DT_STRING";
ctx->SetInvalidOutput(0);
return;
}
xla::ComputationBuilder* b = ctx->builder();
// To avoid blowups for large constants filled with the same value,

View File

@ -345,6 +345,16 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
expression->set_constant_value(constant);
}
void XlaOpKernelContext::SetInvalidOutput(int index) {
const TensorShape shape;
Tensor* output = nullptr;
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
xla::ComputationDataHandle handle;
handle.set_handle(0);
expression->set_handle(handle);
}
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
Tensor* output = nullptr;
// The shape of the output tensor is the shape of the resource itself

View File

@ -142,6 +142,10 @@ class XlaOpKernelContext {
// SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor);
// Sets output 'index' to an invalid value.
// Any subsequent attempt to consume this output will cause an error.
void SetInvalidOutput(int index);
// Status handling.
void SetStatus(const Status& status) { context_->SetStatus(status); }
Status status() { return context_->status(); }

View File

@ -265,6 +265,42 @@ bool BufferAssignment::SharesSliceAtIndex(
GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}
bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const {
using SliceSet =
FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
// Gets the slices of all of instr's subshapes. If any subshape doesn't have an
// assigned slice, returns the empty set.
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
SliceSet slices;
Status status = ShapeUtil::ForEachSubshapeWithStatus(
instr->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
auto shape_slices = GetAllSlices(instr, index);
if (shape_slices.empty()) {
return InvalidArgument("No slices assigned to part of instr.");
}
slices.insert(shape_slices.begin(), shape_slices.end());
return Status::OK();
});
if (!status.ok()) {
return {};
}
return slices;
};
SliceSet slices_a = collect_slices(hlo_a);
SliceSet slices_b = collect_slices(hlo_b);
// hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
// didn't return the empty set) for both HLOs, and the two resulting sets of
// slices are disjoint.
return !slices_a.empty() && !slices_b.empty() &&
std::none_of(slices_a.begin(), slices_a.end(),
[&](const BufferAllocation::Slice& slice) {
return slices_b.count(slice) > 0;
});
}
StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
return GetUniqueTopLevelSlice(

View File

@ -327,6 +327,12 @@ class BufferAssignment {
return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
}
// Returns true if hlo_a and hlo_b both have at least one buffer assigned for
// their top-level and each of their nested shape indices, and if hlo_a's
// buffers are all different from hlo_b's buffers.
bool HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const;
// Returns the underlying points-to analysis used for this assignment.
const TuplePointsToAnalysis& points_to_analysis() const {
return liveness_->points_to_analysis();

View File

@ -27,14 +27,8 @@ namespace se = ::perftools::gputools;
namespace xla {
/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_;
/* static */ void Compiler::LazyInitMutex() {
static std::once_flag mutex_init_flag;
std::call_once(mutex_init_flag, []() {
Compiler::platform_compiler_mutex_ = new tensorflow::mutex;
});
}
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
tensorflow::LINKER_INITIALIZED);
/* static */ std::map<perftools::gputools::Platform::Id,
Compiler::CompilerFactory>*
@ -55,8 +49,7 @@ Compiler::GetPlatformCompilers() {
/* static */ void Compiler::RegisterCompilerFactory(
se::Platform::Id platform_id,
std::function<std::unique_ptr<Compiler>()> compiler_factory) {
LazyInitMutex();
tensorflow::mutex_lock lock(*platform_compiler_mutex_);
tensorflow::mutex_lock lock(platform_compiler_mutex_);
auto* factories = GetPlatformCompilerFactories();
CHECK(factories->find(platform_id) == factories->end())
<< "Compiler factory already registered for platform";
@ -65,8 +58,7 @@ Compiler::GetPlatformCompilers() {
/* static */ StatusOr<Compiler*> Compiler::GetForPlatform(
const se::Platform* platform) {
LazyInitMutex();
tensorflow::mutex_lock lock(*platform_compiler_mutex_);
tensorflow::mutex_lock lock(platform_compiler_mutex_);
auto* compilers = GetPlatformCompilers();
// See if we already instantiated a compiler for this platform.

View File

@ -157,8 +157,7 @@ class Compiler {
private:
// Mutex that guards the platform-compiler map.
static tensorflow::mutex* platform_compiler_mutex_;
static void LazyInitMutex();
static tensorflow::mutex platform_compiler_mutex_;
// Map from platform kind to compiler factory.
static std::map<perftools::gputools::Platform::Id, CompilerFactory>*

View File

@ -94,7 +94,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
se::Platform::Id platform_id,
ComputationPlacerCreationFunction creation_function) {
tensorflow::mutex_lock lock(
*ComputationPlacer::platform_computation_placer_mutex());
ComputationPlacer::platform_computation_placer_mutex_);
auto* computation_placers = GetPlatformComputationPlacers();
CHECK(computation_placers->find(platform_id) == computation_placers->end());
(*computation_placers)[platform_id].creation_function = creation_function;
@ -103,7 +103,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
const se::Platform* platform) {
tensorflow::mutex_lock lock(
*ComputationPlacer::platform_computation_placer_mutex());
ComputationPlacer::platform_computation_placer_mutex_);
auto* computation_placers = GetPlatformComputationPlacers();
auto it = computation_placers->find(platform->id());
@ -122,11 +122,9 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
return it->second.placer.get();
}
/* static */ tensorflow::mutex*
ComputationPlacer::platform_computation_placer_mutex() {
static tensorflow::mutex* m = new tensorflow::mutex;
return m;
}
/* static */ tensorflow::mutex
ComputationPlacer::platform_computation_placer_mutex_(
tensorflow::LINKER_INITIALIZED);
/* static */ std::map<perftools::gputools::Platform::Id,
ComputationPlacer::State>*

View File

@ -89,11 +89,8 @@ class ComputationPlacer {
const perftools::gputools::Platform* platform);
private:
// Routine that returns the mutex that guards the platform-to-computation
// placer map. Done as a routine to ensure correct initialization ordering,
// since RegisterComputationPlacer can be called during program initialization
// time.
static tensorflow::mutex* platform_computation_placer_mutex();
// The mutex that guards the platform-to-computation placer map.
static tensorflow::mutex platform_computation_placer_mutex_;
// State kept for each kind of ComputationPlacer. Registration functions set
// up creation_function, and then we use that to lazily create "placer" the

View File

@ -17,6 +17,7 @@ package_group(
load(":build_defs.bzl", "runtime_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
# Filegroup used to collect source files for dependency checking.
filegroup(
@ -83,6 +84,7 @@ cc_library(
":cpu_options",
":cpu_parallelization_preparation",
":disassembler",
":dot_op_emitter",
":ir_emission_utils",
":ir_emitter",
":layout_assignment",
@ -156,21 +158,23 @@ cc_library(
":custom_call_target_registry",
":disassembler",
":external_constant_pool",
":orc_jit_memory_mapper",
":runtime_conv2d",
":runtime_fork_join",
":runtime_matmul",
":runtime_single_threaded_conv2d",
":runtime_single_threaded_matmul",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:execution_engine",
"@llvm//:core",
"@llvm//:mc", # fixdeps: keep
"@llvm//:orc_jit",
"@llvm//:support",
"@llvm//:target", # fixdeps: keep
],
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
] + ORC_JIT_MEMORY_MAPPER_TARGETS,
)
cc_library(
@ -282,7 +286,6 @@ cc_library(
deps = [
":cpu_options",
":cpu_runtime",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
@ -619,6 +622,7 @@ cc_library(
srcs = ["layout_assignment.cc"],
hdrs = ["layout_assignment.h"],
deps = [
":dot_op_emitter",
":ir_emission_utils",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:computation_layout",
@ -706,6 +710,7 @@ cc_library(
srcs = ["parallel_task_assignment.cc"],
hdrs = ["parallel_task_assignment.h"],
deps = [
":dot_op_emitter",
":ir_emission_utils",
":shape_partition",
"//tensorflow/compiler/xla/service:hlo",
@ -735,6 +740,16 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "orc_jit_memory_mapper",
srcs = ["orc_jit_memory_mapper.cc"],
hdrs = ["orc_jit_memory_mapper.h"],
deps = [
"//tensorflow/core:lib",
"@llvm//:execution_engine",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h"
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h"

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@ -950,5 +949,119 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
return index;
}
// Return whether the given shape is a matrix with no padding.
static bool IsRank2WithNoPadding(const Shape& shape) {
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape) {
// The inputs and the output must
// 1) be matrices with no padding, and
// 2) have an allowed element type.
return output_shape.element_type() == F32 &&
IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape);
}
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
// For certain types of Dot, we can call Eigen
if (hlo.opcode() == HloOpcode::kDot) {
const Shape& lhs_shape = hlo.operand(0)->shape();
const Shape& rhs_shape = hlo.operand(1)->shape();
if (ShapeUtil::HasZeroElements(lhs_shape) ||
ShapeUtil::HasZeroElements(rhs_shape)) {
return false;
}
if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
DotInLlvmIrProfitable::kYes ||
ProfitableToImplementDotInTiledLlvmIr(hlo)) {
return false;
}
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
return true;
}
}
if (hlo.opcode() == HloOpcode::kFusion &&
hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
auto* dot = hlo.fused_expression_root();
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
if (ShapeUtil::HasZeroElements(lhs_shape) ||
ShapeUtil::HasZeroElements(rhs_shape)) {
return false;
}
return true;
}
return false;
}
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
const HloInstruction& dot) {
if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) {
const Shape& result_shape = dot.shape();
// kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1
// cache line size, so that we can have the reduction dimension of both the
// LHS and RHS matrices and still have some space "left over". This needs
// to be tuned further.
const int64 kReductionDimensionThresholdBytes = 8 * 1024;
const bool single_threaded_eigen =
!dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();
// This is the point at which it is better to call into Eigen and shard the
// dot across multiple worker threads. This is a rough estimate by running
// a matmult benchmark on my local machine, and it can be tuned further.
const int64 kMaxSingleThreadedFlops = 16 * 1024;
const int64 M = result_shape.dimensions(0);
const int64 N = result_shape.dimensions(1);
const int64 K = dot.operand(1)->shape().dimensions(0);
const int64 primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type());
if (M == 1 &&
K * primitive_type_size <= kReductionDimensionThresholdBytes &&
(single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) {
// Heuristics:
//
// - Look for a configuration where we will likely be able to keep LHS in
// L1 and do a cache-optimal traversal of RHS.
//
// - Bail out on matrices that are large enough that Eigen can profitably
// shard the computation across multiple cores. This only applies when
// multi-threading is enabled.
return LayoutUtil::IsMonotonicWithDim0Major(
dot.operand(1)->shape().layout())
? DotInLlvmIrProfitable::kWithColumnMajorRhs
: DotInLlvmIrProfitable::kYes;
}
}
return DotInLlvmIrProfitable::kNo;
}
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
// Any Matrix-Vector product of floating point or integral type, or
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
// IR implementation.
const Shape& shape = dot.shape();
return shape.dimensions_size() == 2 &&
(shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
(primitive_util::IsFloatingPointType(shape.element_type()) ||
primitive_util::IsIntegralType(shape.element_type()));
}
} // namespace cpu
} // namespace xla
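
A worked instance of the untiled-dot heuristic above, with numbers chosen purely for illustration: an F32 dot producing shape [1, 8] from a [1, 1024] LHS and a [1024, 8] RHS gives

M = 1, N = 8, K = 1024, primitive_type_size = 4
K * primitive_type_size = 4096 <= kReductionDimensionThresholdBytes (8192)
M * K * N = 8192 <= kMaxSingleThreadedFlops (16384)

so ProfitableToImplementDotInUntiledLlvmIr selects the untiled lowering, returning kWithColumnMajorRhs if the RHS layout is monotonic dim-0 major and kYes otherwise.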

View File

@ -30,6 +30,26 @@ limitations under the License.
namespace xla {
namespace cpu {
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo);
enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
// Returns a value indicating whether (and under what conditions) lowering
// |dot| as an untiled LLVM IR dot operation will be profitable, compared to
// calling into Eigen or emitting a tiled LLVM IR implementation. Possible
// return values are:
//
// * DotInLlvmIrProfitable::kYes - always profitable.
// * DotInLlvmIrProfitable::kNo - never profitable.
// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
// the Rhs layout column major.
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
const HloInstruction& dot);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
// for |dot|.
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
public:

View File

@ -74,122 +74,5 @@ bool PotentiallyImplementedAsEigenConvolution(
kernel_shape.dimensions_size() - 1;
}
namespace {
// Return whether the given shape is a matrix with no padding.
bool IsRank2WithNoPadding(const Shape& shape) {
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape) {
// The inputs and the output must
// 1) be matrices with no padding, and
// 2) have an allowed element type.
return output_shape.element_type() == F32 &&
IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape);
}
} // namespace
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
// For certain types of Dot, we can call Eigen
if (hlo.opcode() == HloOpcode::kDot) {
const Shape& lhs_shape = hlo.operand(0)->shape();
const Shape& rhs_shape = hlo.operand(1)->shape();
if (ShapeUtil::HasZeroElements(lhs_shape) ||
ShapeUtil::HasZeroElements(rhs_shape)) {
return false;
}
if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
DotInLlvmIrProfitable::kYes ||
ProfitableToImplementDotInTiledLlvmIr(hlo)) {
return false;
}
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
return true;
}
}
if (hlo.opcode() == HloOpcode::kFusion &&
hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
auto* dot = hlo.fused_expression_root();
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
if (ShapeUtil::HasZeroElements(lhs_shape) ||
ShapeUtil::HasZeroElements(rhs_shape)) {
return false;
}
return true;
}
return false;
}
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
const HloInstruction& dot) {
if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) {
const Shape& result_shape = dot.shape();
// kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1
// cache line size, so that we can have the reduction dimension of both the
// LHS and RHS matrices and still have some space "left over". This needs
// to be tuned further.
const int64 kReductionDimensionThresholdBytes = 8 * 1024;
const bool single_threaded_eigen =
!dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();
// This is the point at which it is better to call into Eigen and shard the
// dot across multiple worker threads. This is a rough estimate by running
// a matmult benchmark on my local machine, and it can be tuned further.
const int64 kMaxSingleThreadedFlops = 16 * 1024;
const int64 M = result_shape.dimensions(0);
const int64 N = result_shape.dimensions(1);
const int64 K = dot.operand(1)->shape().dimensions(0);
const int64 primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type());
if (M == 1 &&
K * primitive_type_size <= kReductionDimensionThresholdBytes &&
(single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) {
// Heuristics:
//
// - Look for a configuration where we will likely be able to keep LHS in
// L1 and do a cache-optimal traversal of RHS.
//
// - Bail out on matrices that are large enough that Eigen can profitably
// shard the computation across multiple cores. This only applies when
// multi-threading is enabled.
return LayoutUtil::IsMonotonicWithDim0Major(
dot.operand(1)->shape().layout())
? DotInLlvmIrProfitable::kWithColumnMajorRhs
: DotInLlvmIrProfitable::kYes;
}
}
return DotInLlvmIrProfitable::kNo;
}
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
// Any Matrix-Vector product of floating point or integral type, or
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
// IR implementation.
const Shape& shape = dot.shape();
return shape.dimensions_size() == 2 &&
(shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
(primitive_util::IsFloatingPointType(shape.element_type()) ||
primitive_util::IsIntegralType(shape.element_type()));
}
} // namespace cpu
} // namespace xla

View File

@ -23,27 +23,6 @@ namespace cpu {
bool PotentiallyImplementedAsEigenConvolution(
const HloInstruction& convolution);
bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot);
enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
// Returns a value indicating whether (and under what conditions) lowering
// |dot| as an untiled LLVM IR dot operation will be profitable, compared to
// calling into Eigen or emitting a tiled LLVM IR implementation. Possible
// return values are:
//
// * DotInLlvmIrProfitable::kYes - always profitable.
// * DotInLlvmIrProfitable::kNo - never profitable.
// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
// the Rhs layout column major.
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
const HloInstruction& dot);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
// for |dot|.
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
} // namespace cpu
} // namespace xla

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -0,0 +1,40 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
namespace cpu {
namespace orc_jit_memory_mapper {
static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED);
static llvm::SectionMemoryManager::MemoryMapper* mapper_instance
GUARDED_BY(mapper_instance_mutex) = nullptr;
llvm::SectionMemoryManager::MemoryMapper* GetInstance() {
tensorflow::mutex_lock lock(mapper_instance_mutex);
return mapper_instance;
}
Registrar::Registrar(
std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper> mapper) {
tensorflow::mutex_lock lock(mapper_instance_mutex);
mapper_instance = mapper.release();
}
} // namespace orc_jit_memory_mapper
} // namespace cpu
} // namespace xla

View File

@ -0,0 +1,56 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
#include <memory>
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
namespace xla {
namespace cpu {
namespace orc_jit_memory_mapper {
// Returns the registered memory mapper if there is one. Returns nullptr if no
// memory mapper is registered.
llvm::SectionMemoryManager::MemoryMapper* GetInstance();
class Registrar {
public:
// Registers the `mapper` as a memory mapper. This is a no-op if `mapper` is
// null. Precondition: no other memory mapper has been registered yet.
explicit Registrar(
std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper> mapper);
};
} // namespace orc_jit_memory_mapper
#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(mapper_instance, ctr) \
static ::xla::cpu::orc_jit_memory_mapper::Registrar \
XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr)(mapper_instance)
// __COUNTER__ must go through another macro to be properly expanded
#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr) \
__orc_jit_memory_mapper_registrar_##ctr
// Registers the std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper>
// returned by the `factory` expression. `factory` is allowed to evaluate to
// a null unique_ptr in which case this macro does nothing.
#define XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(factory) \
XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(factory, __COUNTER__)
} // namespace cpu
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
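
A usage sketch for the registration macro (the factory below is hypothetical; note the header explicitly allows a null result, which makes registration a no-op):

#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"

// Hypothetical factory, for illustration; a real one would return a
// subclass of llvm::SectionMemoryManager::MemoryMapper that overrides
// its allocate/protect/release hooks.
static std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper>
CreateJitMemoryMapper() {
  return nullptr;  // null is allowed: registration becomes a no-op.
}

// Expands to a static Registrar, so this runs at static-init time.
XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(CreateJitMemoryMapper());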

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
@ -125,8 +126,10 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
/*MAttrs=*/DetectMachineAttributes()))),
disassembler_(*target_machine_),
data_layout_(target_machine_->createDataLayout()),
object_layer_(
[] { return std::make_shared<llvm::SectionMemoryManager>(); }),
object_layer_([] {
return std::make_shared<llvm::SectionMemoryManager>(
orc_jit_memory_mapper::GetInstance());
}),
compile_layer_(
object_layer_,
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>
@ -258,7 +259,9 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
return InternalError("couldn't get temp CUBIN file name");
}
auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] {
TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(cubin_path));
// The CUBIN file may never be created, so failure to delete it should not
// produce a TF error.
tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError();
});
tensorflow::SubProcess ptxas_info_dumper;
std::vector<string> ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path,
@ -500,10 +503,24 @@ std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
VLOG(2) << "Compiled PTX size:" << ptx.size()
<< " CUBIN size: " << cache_value->cubin_data.size();
} else {
LOG(WARNING)
<< "Failed to compile ptx to cubin. Will attempt to let "
"GPU driver compile the ptx. "
<< maybe_cubin.status();
bool log_warning = true;
if (maybe_cubin.status().code() ==
tensorflow::error::Code::NOT_FOUND) {
// Missing ptxas is expected in some environments where CUDA SDK
// binaries are not available. We don't want to spam logs with
// identical warnings in this case.
// TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N
// for more general usage.
static std::atomic<bool> warning_done(false);
log_warning = !warning_done.exchange(true);
}
if (log_warning) {
LOG(WARNING)
<< "Failed to compile ptx to cubin. Will attempt to let "
"GPU driver compile the ptx. "
<< maybe_cubin.status();
}
}
}
cache_value->compilation_done = true;
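
The once-only warning above rests on a small atomic idiom worth noting in isolation (it relies on the <atomic> include added at the top of this file; the message text here is illustrative):

// First caller flips warning_done from false to true and logs; every
// later (possibly concurrent) caller sees exchange(true) return true
// and stays quiet.
static std::atomic<bool> warning_done(false);
if (!warning_done.exchange(true)) {
  LOG(WARNING) << "ptxas unavailable; deferring PTX compilation to the driver.";
}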

View File

@ -166,11 +166,46 @@ void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
*(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}
// Determines whether hlo's buffers are never modified within the execution of
// consumer.
static bool BuffersInvariantWithinConsumer(
const HloInstruction& hlo, const HloInstruction& consumer,
const BufferAssignment* buffer_assignment) {
// Check if consumer is inside a fusion node -- if so, "dereference" it until
// we get to a non-fusion node.
const HloInstruction* c = &consumer;
while (c->IsFused()) {
c = c->parent()->FusionInstruction();
}
// If, after dereferencing c, we end up with a node that's not inside our
// module's top-level computation (say our node is inside a while loop), we
// give up on marking the array as invariant, because this HLO may be run
// times (e.g. multiple while loop iterations, or multiple invocations of a
// reducer's computation). TODO(jlebar): We could relax this constraint if we
// emitted an llvm.invariant.group.barrier at the end of the computation.
return c->parent() == c->GetModule()->entry_computation() &&
buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
}
llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
const HloInstruction& consumer,
const ShapeIndex& shape_index) {
llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index),
ShapeUtil::GetSubshape(hlo.shape(), shape_index));
alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);
// The GPU backend emits one kernel per top-level HLO, and LLVM views
// execution of one kernel as the "whole program" executed on the GPU.
// Therefore if hlo's output buffer is not modified within consumer, and if
// consumer runs hlo only once (so that it doesn't create two different
// outputs), then we can mark ir_array as invariant over the whole program.
if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
VLOG(2) << "Marking " << hlo.name() << " as invariant within "
<< consumer.name();
ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
}
return ir_array;
}
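
For reference, "invariant over the whole program" here amounts to attaching LLVM's !invariant.load metadata to loads later emitted from this IrArray (in textual IR, roughly: %v = load float, float* %p, !invariant.load !0), which licenses LLVM to hoist and combine those loads anywhere in the kernel.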

View File

@ -76,8 +76,15 @@ class HloToIrBindings {
return it->second.element(shape_index);
}
// Return the underlying IrArray of the output of the given instruction.
// Returns the IrArray which contains the output of hlo.
//
// consumer is the HLO in which this IrArray is used -- we use this to (try
// to) add metadata indicating that the array is invariant within consumer.
//
// To get the buffer into which hlo should write its own output, call
// GetIrArray(hlo, hlo).
llvm_ir::IrArray GetIrArray(const HloInstruction& hlo,
const HloInstruction& consumer,
const ShapeIndex& shape_index = {});
private:

View File

@ -68,7 +68,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_);
return GetIrArray(*operand, *hlo)
.EmitReadArrayElement(index, &ir_builder_);
};
}
return EmitTargetElementLoop(
@ -145,7 +146,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
for (const HloInstruction* operand : tuple->operands()) {
base_ptrs.push_back(GetBasePointer(*operand));
}
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_);
llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_,
module_);
return Status::OK();
}
@ -334,7 +336,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
TF_RET_CHECK(pred->shape().element_type() == PRED);
if (ShapeUtil::IsTuple(select->shape())) {
llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
llvm_ir::EmitTupleSelect(GetIrArray(*select, *select),
GetIrArray(*pred, *select),
GetBasePointer(*on_true),
GetBasePointer(*on_false), &ir_builder_, module_);
return Status::OK();
@ -349,9 +352,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
Status IrEmitter::HandleDot(HloInstruction* dot) {
auto lhs_instruction = dot->operand(0);
auto rhs_instruction = dot->operand(1);
const llvm_ir::IrArray& target_array = GetIrArray(*dot);
const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction);
const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction);
const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);
const Shape& lhs_shape = lhs_instruction->shape();
const Shape& rhs_shape = rhs_instruction->shape();
@ -571,7 +574,8 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// Apply the reduction function to the loaded value.
llvm::Value* input_address =
GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_);
GetIrArray(*arg, *reduce)
.EmitArrayElementAddress(input_index, &ir_builder_);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*function, {accumulator_addr, input_address}, accumulator_addr));
@ -587,7 +591,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
std::vector<llvm_ir::IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand));
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_,
&ir_builder_, GetNestedComputer());
@ -622,7 +626,8 @@ Status IrEmitter::HandleRng(HloInstruction* random) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : random->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_);
return GetIrArray(*operand, *random)
.EmitReadArrayElement(index, &ir_builder_);
};
}
// Emits a single-threaded loop because the loop body generated by the element
@ -631,7 +636,7 @@ Status IrEmitter::HandleRng(HloInstruction* random) {
GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
GetNestedComputer())
.MakeElementGenerator(random, operand_to_generator),
GetIrArray(*random), &ir_builder_)
GetIrArray(*random, *random), &ir_builder_)
.EmitLoop(IrName(random));
}

View File

@ -105,10 +105,16 @@ class IrEmitter : public DfsHloVisitorWithDefault {
explicit IrEmitter(const HloModuleConfig& hlo_module_config,
IrEmitterContext* ir_emitter_context, bool is_nested);
// A convenient helper for calling HloToIrBindings::GetIrArray.
// Helper for calling HloToIrBindings::GetIrArray.
//
// Gets the IrArray which contains inst. This array has metadata that makes
// it valid only within the IR that implements consumer. If you are
// implementing an HLO and want to get its own output buffer, call
// GetIrArray(hlo, hlo).
llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
const HloInstruction& consumer,
const ShapeIndex& shape_index = {}) {
return bindings_.GetIrArray(inst, shape_index);
return bindings_.GetIrArray(inst, consumer, shape_index);
}
// A convenient helper for calling HloToIrBindings::GetBasePointer.
llvm::Value* GetBasePointer(const HloInstruction& inst) const {

View File

@ -115,7 +115,8 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
Status IrEmitterNested::EmitTargetElementLoop(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator) {
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_)
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo),
&ir_builder_)
.EmitLoop();
}

View File

@ -282,7 +282,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
std::vector<llvm_ir::IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand));
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter elemental_emitter(
hlo_module_config_, ir_emitter_context_->llvm_module(),
@ -344,7 +344,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
std::vector<llvm_ir::IrArray> operand_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_arrays.push_back(GetIrArray(*operand));
operand_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
ir_emitter_context_->llvm_module(),
@ -355,7 +355,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// Array to write into. Because this is an in-place operation, this is the
// same as operand 0's array.
llvm_ir::IrArray output_array = GetIrArray(*fusion);
llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
update_shape, ir_emitter_context_->device_description());
@ -693,9 +693,10 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
constexpr int64 tile_size = 32;
constexpr int64 num_rows = 8;
int64 num_tiles = EmitTranspose021Tiled(
GetIrArray(*(copy->operand(0)))
GetIrArray(*copy->operand(0), *copy)
.CastToShape(reduced_input_shape, &ir_builder_),
GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_),
GetIrArray(*copy, *copy)
.CastToShape(reduced_output_shape, &ir_builder_),
tile_size, num_rows, &ir_builder_);
UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
LastThunk(), ir_emitter_context_->llvm_module());
@ -850,9 +851,11 @@ Status IrEmitterUnnested::EmitColumnReduction(
&ir_builder_);
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_,
"output_element_address");
llvm::Value* output_address =
GetIrArray(*output, *output)
.EmitArrayElementAddress(
llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_),
&ir_builder_, "output_element_address");
return EmitAtomicOperationForNestedComputation(
*reducer, output_address, partial_reduction_result_address);
};
@ -1116,9 +1119,11 @@ Status IrEmitterUnnested::EmitRowReduction(
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_,
"output_element_address");
llvm::Value* output_address =
GetIrArray(*output, *output)
.EmitArrayElementAddress(
llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_),
&ir_builder_, "output_element_address");
return EmitAtomicOperationForNestedComputation(
*reducer, output_address, partial_reduction_result_address);
};
@ -1258,11 +1263,12 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
reduce, input->shape(),
[this, input](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_);
[&](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*input, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
},
[this, init_value](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*init_value)
[&](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
},
dimensions_to_reduce, reducer);
@ -1426,7 +1432,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArray(*operand));
llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
ir_builder_.CreateStore(operand_data, selected_value_address);
@ -1479,9 +1485,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
ir_builder_.CreateLoad(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_);
GetIrArray(*source, *select_and_scatter)
.EmitArrayElementAddress(source_index, &ir_builder_);
llvm::Value* output_value_address =
GetIrArray(*select_and_scatter)
GetIrArray(*select_and_scatter, *select_and_scatter)
.EmitArrayElementAddress(selected_index, &ir_builder_);
return EmitAtomicOperationForNestedComputation(
*select_and_scatter->scatter(), output_value_address,
@ -1758,7 +1765,7 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
return EmitTargetElementLoopInThunk(
*hlo,
[=](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*init_value)
return GetIrArray(*init_value, *hlo)
.EmitReadArrayElement(index, &ir_builder_);
},
thunk);
@ -1859,7 +1866,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo),
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
launch_dimensions, &ir_builder_)
.EmitLoop(IrName(&hlo));
}
@ -1867,7 +1874,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
// For multiple outputs fusion, we need to emit each operand and the root.
std::vector<llvm_ir::IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, {i}));
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
}
TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
launch_dimensions, &ir_builder_)
@ -1878,7 +1885,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_,
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_,
module_);
return Status::OK();
}

View File

@ -75,11 +75,43 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
std::forward_as_tuple(value_id, instruction, index, is_phi));
CHECK(emplaced.second);
VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
return &emplaced.first->second;
}
void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
values_.erase(value_id);
void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
HloValue& value = values_.at(value_id);
VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
value_ids_to_delete_.push_back(value_id);
}
void HloDataflowAnalysis::DeleteMarkedValues() {
#ifndef NDEBUG
// Verify that no marked-for-deletion values are in any of the value sets.
tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
value_ids_to_delete_.end());
for (const auto& pair : value_sets_) {
const HloInstruction* instruction = pair.first;
const InstructionValueSet& instruction_value_set = pair.second;
for (const auto& index_value_set : instruction_value_set) {
const HloValueSet& value_set = index_value_set.second;
for (const HloValue* value : value_set.values()) {
DCHECK(!ContainsKey(id_set, value->id()))
<< "Value " << value->ToShortString()
<< " marked for deletion, but still exists in value set for "
"instruction "
<< instruction->name();
}
}
}
#endif
for (HloValue::Id value_id : value_ids_to_delete_) {
values_.erase(value_id);
}
value_ids_to_delete_.clear();
}
string HloDataflowAnalysis::ToString() const {
@ -121,6 +153,7 @@ bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
VLOG(4) << "Phi(" << instruction->name() << ")";
for (const InstructionValueSet* input : inputs) {
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
@ -183,7 +216,7 @@ bool HloDataflowAnalysis::Phi(
} else if (current_value != &new_value) {
if (current_value_defined_here) {
// Remove the existing phi.
DeleteHloValue(current_value->id());
MarkValueForDeletion(current_value->id());
}
value_set.Clear();
value_set.AddValue(&new_value);
@ -193,7 +226,8 @@ bool HloDataflowAnalysis::Phi(
// Multiple distinct values reach this point. A phi value is
// necessary.
CHECK_GT(input_value_ids.size(), 1);
if (current_value == nullptr || !current_value->is_phi()) {
if (current_value == nullptr ||
!(current_value->is_phi() && current_value_defined_here)) {
value_set.Clear();
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
changed = true;
@ -485,11 +519,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
}
}
void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
for (HloInstruction* instruction : instructions) {
worklist.push(instruction);
for (HloComputation* computation : module_->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
worklist.push(instruction);
}
}
while (!worklist.empty()) {
@ -662,20 +698,17 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
dataflow_analysis->Propagate();
// Construct list of all instructions to initialize the worklist to propagate
// the data flow. For efficiency sort the instruction in post order so
// producers appear before consumers.
std::vector<HloInstruction*> all_instructions;
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
all_instructions.push_back(instruction);
}
}
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
// Delete all values marked for deletion.
dataflow_analysis->DeleteMarkedValues();
// Add in positions to all values.
// Gather and set all non-definition positions of all values. Value deletion
// is rare, so just use a vector indexed by Value::Id rather than a map from
// Value::Id to positions. There should be very few holes in the vector, and
// lookup is faster.
std::vector<std::vector<HloPosition>> value_positions(
dataflow_analysis->next_value_id_);
for (const HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
for (const auto& pair :
@ -684,13 +717,18 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
const HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction) {
dataflow_analysis->GetValue(value->id())
.AddPosition(instruction, index);
value_positions[value->id()].push_back(
HloPosition{instruction, index});
}
}
}
}
}
for (auto& pair : dataflow_analysis->values_) {
HloValue::Id value_id = pair.first;
HloValue& value = pair.second;
value.SetPositionsAndComputeUses(value_positions[value_id]);
}
// Construct vector of values.
dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
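The shape of Propagate above, reduced to a toy worklist (hypothetical Node type and Recompute callback, not the real InstructionValueSet update): seed the queue with every node, and whenever recomputing a node changes its value, re-enqueue its users until a fixed point is reached.
#include <queue>
#include <vector>
struct Node {
  int value = 0;
  std::vector<Node*> users;
};
// Recompute(n) returns true iff n's value changed; changed nodes push their
// users back onto the worklist, so the loop stops at a fixed point.
void PropagateToFixedPoint(const std::vector<Node*>& nodes,
                           bool (*Recompute)(Node*)) {
  std::queue<Node*> worklist;
  for (Node* n : nodes) worklist.push(n);
  while (!worklist.empty()) {
    Node* n = worklist.front();
    worklist.pop();
    if (Recompute(n)) {
      for (Node* user : n->users) worklist.push(user);
    }
  }
}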

View File

@ -126,13 +126,16 @@ class HloDataflowAnalysis {
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);
// Delete the HloValue with the given ID.
void DeleteHloValue(HloValue::Id value_id);
// Mark the HloValue with the given ID for deletion.
void MarkValueForDeletion(HloValue::Id value_id);
// Delete all HloValues marked for deletion. Should be called after
// propagation is complete.
void DeleteMarkedValues();
// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
// then propagated throughout the HLO graph by calling
// UpdateInstructionsAndPropagate.
// then be propagated throughout the HLO graph by calling Propagate.
Status InitializeInstructionValueSets();
// Updates the value set of the given instruction based on the values flowing
@ -152,10 +155,8 @@ class HloDataflowAnalysis {
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
// Update the value sets of the given instructions and propagate the
// changes to fixed point.
void UpdateInstructionsAndPropagate(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
// Propagate the dataflow through the module.
void Propagate();
// Return the result of the SSA Phi function applied to the given inputs at
// the given instruction. If skip_top_level is true, then the top level of the
@ -191,6 +192,11 @@ class HloDataflowAnalysis {
// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
// Values marked for deletion during construction. We don't delete them
// immediately because references to them may remain in ValueSets temporarily
// during propagation. After construction, these values are deleted.
std::vector<HloValue::Id> value_ids_to_delete_;
// A vector containing all HloValues sorted by HloValue::Id.
std::vector<const HloValue*> values_vector_;
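The value_ids_to_delete_ mechanism in miniature (toy map keyed by int ids, hypothetical names): marking only records the id, and the actual erase happens in one sweep once propagation can no longer hold stale references.
#include <unordered_map>
#include <vector>
struct ToyAnalysis {
  std::unordered_map<int, int> values;  // id -> value
  std::vector<int> ids_to_delete;       // accumulated during propagation
  // Safe to call mid-propagation: references to the value may still exist.
  void MarkForDeletion(int id) { ids_to_delete.push_back(id); }
  // Called once propagation is complete and no stale references remain.
  void DeleteMarked() {
    for (int id : ids_to_delete) values.erase(id);
    ids_to_delete.clear();
  }
};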

View File

@ -211,10 +211,10 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}},
HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}},
HloPosition{gte_out, {}}));
// Constant values should have no uses though one is live out. The positions
// where they appear as operands are on instructions which do not use the
// values (eg, Tuple).
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
// Constant values should have only a single use, which is the root of the
// computation.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(),
UnorderedElementsAre(HloUse{gte_out, 0, {0}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
// The top-level tuple values are used in GTE instructions.
@ -274,12 +274,11 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add));
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}));
UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
UnorderedElementsAre(HloUse{add, 1, {}}));
UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
}
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
@ -323,18 +322,17 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}));
UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}},
HloUse{add, 0, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
UnorderedElementsAre(HloUse{add, 1, {}}));
UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}},
HloUse{add, 1, {}}));
// The Add from the subcomputation is used as both operands of the Subtract.
EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(),
UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}}));
EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation());
}
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
@ -408,7 +406,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
auto outer_param1 = outer_builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
// Swizzle parameters.
outer_builder.AddInstruction(HloInstruction::CreateCall(
auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {outer_param1, outer_param0}, inner_computation));
HloComputation* outer_computation =
module_->AddEmbeddedComputation(outer_builder.Build());
@ -418,7 +416,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
builder.AddInstruction(HloInstruction::CreateCall(
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, outer_computation));
module_->AddEntryComputation(builder.Build());
@ -431,10 +429,14 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
// Verify that the uses of the constants are properly swizzled by parameter
// permutation in nested_call.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
HloUse{add, 1, {}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(constant2).uses(),
UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
HloUse{add, 0, {}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
}
@ -469,7 +471,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
body_builder.AddInstruction(
auto body_root = body_builder.AddInstruction(
HloInstruction::CreateTuple({body_element_0, add}));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
@ -496,8 +498,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
EXPECT_TRUE(
analysis.GetValueDefinedAt(cond_constant).live_out_of_computation());
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
if (ssa_form) {
@ -517,14 +517,14 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}},
HloUse{xla_while, 0, {0}}));
// Constant1 passes through the body and out of the module.
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
} else {
// While instruction and subcomputation parameters should not define values
@ -538,7 +538,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
}
}
@ -915,9 +914,11 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
HloUse{select12, 1, {}}));
// The two constant values just pass through the Selects and are not
// used. They are live out however.
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
// used except at the root. They are live out however.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{select1234, 1, {0}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
UnorderedElementsAre(HloUse{select1234, 1, {0}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
}
@ -1318,7 +1319,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
auto entry = module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
RunAnalysis(ssa_form);
SequentialHloOrdering::HloModuleSequence sequence;
sequence.insert({entry, {param, xla_while}});
@ -1329,12 +1330,6 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
SequentialHloOrdering ordering(module_.get(), sequence);
// 'add' is the body root even though later instructions follow in the order
// like 'dead_negate'. Only 'add' should be live out of the computation.
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
EXPECT_FALSE(
analysis.GetValueDefinedAt(dead_negate).live_out_of_computation());
// 'add' is live out of the body and will interfere with later instructions
// such as 'dead_constant' and 'dead_negate'.
EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
namespace xla {
HloToProfileIndex::HloToProfileIndex(const HloModule& module) {
HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
size_t current_profile_index = 0;
for (xla::HloComputation* computation : module.MakeComputationPostOrder()) {
InsertOrDie(&computation_to_profile_idx_, computation,
@ -41,24 +41,24 @@ HloToProfileIndex::HloToProfileIndex(const HloModule& module) {
}
static HloProfilePrinter CreateOwnedHloProfilePrinter(
const HloToProfileIndex& hlo_to_profile_index,
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis) {
using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo;
HloComputationInfo* computation_infos =
new HloComputationInfo[hlo_to_profile_index.computation_count()];
new HloComputationInfo[hlo_profile_index_map.computation_count()];
// There are two "indices" in play here. The first one is the index of the
// HloComputationInfo or HloInstructionInfo in the array that contains said
// HloComputationInfo or HloInstructionInfo. The second index is the index of
// the HloComputationInfo or HloInstructionInfo in the profile counters array,
// as decided by hlo_to_profile_index. The latter index is always referred to
// as "profile_index".
// as decided by hlo_profile_index_map. The latter index is always referred
// to as "profile_index".
size_t computation_index_in_static_data = 0;
size_t max_profile_index = hlo_to_profile_index.total_count();
for (const auto& pair : hlo_to_profile_index.computation_to_profile_idx()) {
size_t max_profile_index = hlo_profile_index_map.total_count();
for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) {
CHECK_LT(pair.second, max_profile_index);
const HloComputation* computation = pair.first;
size_t current_computation_index = computation_index_in_static_data++;
@ -85,7 +85,7 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo);
instruction_info->seconds = cost_analysis.seconds(*hlo);
instruction_info->profile_index =
hlo_to_profile_index.GetProfileIndexFor(*hlo);
hlo_profile_index_map.GetProfileIndexFor(*hlo);
CHECK_LT(instruction_info->profile_index, max_profile_index);
}
}
@ -109,26 +109,26 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
};
return HloProfilePrinter(computation_infos,
hlo_to_profile_index.computation_count(), deleter);
hlo_profile_index_map.computation_count(), deleter);
}
HloExecutionProfile::HloExecutionProfile(const HloModule& module,
const HloCostAnalysis& cost_analysis)
: hlo_to_profile_index_(module),
: hlo_profile_index_map_(module),
hlo_profile_printer_(
CreateOwnedHloProfilePrinter(hlo_to_profile_index_, cost_analysis)),
CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)),
profile_counters_(
/*count*/ hlo_to_profile_index_.total_count(),
/*count*/ hlo_profile_index_map_.total_count(),
/*value*/ 0) {}
void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo,
uint64 cycles_taken) {
profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(*hlo)] =
profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] =
cycles_taken;
}
uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {
return profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(hlo)];
return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)];
}
string HloExecutionProfile::ToString(

View File

@ -29,18 +29,18 @@ namespace xla {
class HloInstruction;
// Maps all HloInstructions and HloComputions in an HloModule to integers.
// These integers form the contiguous range [0, GetTotalCount()).
class HloToProfileIndex {
// Maps all HloInstructions and HloComputations in an HloModule to integers.
// These integers form the contiguous range [0, total_count()).
class HloProfileIndexMap {
public:
// Scans `module` to populate this instance of HloToProfileIndex.
explicit HloToProfileIndex(const HloModule& module);
// Scans `module` to populate this instance of HloProfileIndexMap.
explicit HloProfileIndexMap(const HloModule& module);
HloToProfileIndex(const HloToProfileIndex&) = default;
HloToProfileIndex(HloToProfileIndex&&) = default;
HloProfileIndexMap(const HloProfileIndexMap&) = default;
HloProfileIndexMap(HloProfileIndexMap&&) = default;
HloToProfileIndex& operator=(const HloToProfileIndex&) = default;
HloToProfileIndex& operator=(HloToProfileIndex&&) = default;
HloProfileIndexMap& operator=(const HloProfileIndexMap&) = default;
HloProfileIndexMap& operator=(HloProfileIndexMap&&) = default;
size_t GetProfileIndexFor(const HloInstruction& instruction) const {
return FindOrDie(instruction_to_profile_idx(), &instruction);
@ -97,14 +97,14 @@ class HloExecutionProfile {
// Return the number of cycles this computation took to execute.
uint64 total_cycles_executed(const HloComputation& computation) const {
return profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(
return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(
computation)];
}
// Record how many cycles a computation took to execute.
void set_total_cycles_executed(const HloComputation& computation,
uint64 total_cycles_executed) {
profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(computation)] =
profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(computation)] =
total_cycles_executed;
}
@ -117,9 +117,9 @@ class HloExecutionProfile {
string ToString(const DeviceDescription& device_description) const;
private:
// hlo_to_profile_index_ maps an Hlo entity (computation or instruction) to an
// index in profile_counters_.
HloToProfileIndex hlo_to_profile_index_;
// hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to
// an index in profile_counters_.
HloProfileIndexMap hlo_profile_index_map_;
// Used to print profile_counters_ in a human readable form.
HloProfilePrinter hlo_profile_printer_;
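The renamed class keeps the same design: assign each profiled entity a contiguous index in [0, total_count()) and address a flat counter vector through it. A toy sketch with hypothetical string keys in place of HLO pointers:
#include <cstdint>
#include <map>
#include <string>
#include <vector>
class ToyProfile {
 public:
  explicit ToyProfile(const std::vector<std::string>& entities) {
    for (const std::string& e : entities) {
      const size_t idx = index_.size();  // contiguous [0, total_count())
      index_.emplace(e, idx);
    }
    counters_.resize(index_.size(), 0);
  }
  void SetCycles(const std::string& e, uint64_t cycles) {
    counters_[index_.at(e)] = cycles;
  }
  uint64_t GetCycles(const std::string& e) const {
    return counters_[index_.at(e)];
  }
 private:
  std::map<std::string, size_t> index_;
  std::vector<uint64_t> counters_;
};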

View File

@ -970,6 +970,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
return kBrown;
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
case HloOpcode::kCall:
@ -1003,7 +1004,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
}
string extended_opcode =
StrCat(HloOpcodeString(instr->opcode()),
instr->opcode() == HloOpcode::kFusion
instr->opcode() != HloOpcode::kFusion
? ""
: StrCat(":", xla::ToString(instr->fusion_kind())));
// If the name does not contain the opcode, render both.

View File

@ -43,6 +43,7 @@ limitations under the License.
namespace xla {
using tensorflow::str_util::CEscape;
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
@ -1209,6 +1210,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
break;
case HloOpcode::kConditional:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
@ -1602,6 +1604,7 @@ bool HloInstruction::IdenticalSlowPath(
return dimensions() == other.dimensions();
// These opcodes are not yet supported.
case HloOpcode::kConditional:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kSort:
@ -1965,6 +1968,13 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
}),
"}"));
}
if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) {
extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\""));
}
if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) {
extra.push_back(
StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
}
return extra;
}
@ -2347,6 +2357,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);
// These opcodes are not handled here.
case HloOpcode::kConditional:
case HloOpcode::kTrace:
break;
}
@ -2920,7 +2931,6 @@ string PaddingConfigToString(const PaddingConfig& padding) {
string OpMetadataToString(const OpMetadata& metadata) {
std::vector<string> result;
using tensorflow::str_util::CEscape;
if (!metadata.op_type().empty()) {
result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
}

View File

@ -58,6 +58,7 @@ namespace xla {
V(kClamp, "clamp") \
V(kComplex, "complex") \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
V(kConditional, "conditional") \
V(kConstant, "constant") \
V(kConvert, "convert") \
V(kConvolution, "convolution") \

View File

@ -173,6 +173,19 @@ bool HloOrdering::UseIsBeforeValueDefinition(
return true;
}
}
// The use at a call occurs before values that are defined in the called
// computation.
if (use.instruction->opcode() == HloOpcode::kCall) {
const HloInstruction* call = use.instruction;
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
call->to_apply())) {
VLOG(4) << " use is call " << use.instruction->name()
<< " and def is in called computation";
return true;
}
}
VLOG(4) << " use is not before value";
return false;
}
@ -187,23 +200,6 @@ bool HloOrdering::LiveRangeStrictlyBefore(
return false;
}
// Live-out values from the module can never have ranges strictly before any
// other value.
if (a.live_out_of_module()) {
VLOG(4) << "a is live out of module";
return false;
}
// Live-out values of computations can never have ranges strictly before any
// other value in the computation (including values nested in
// subcomputations).
if (a.live_out_of_computation() &&
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
a.defining_instruction()->parent())) {
VLOG(4) << "a is live out of computation containing b";
return false;
}
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {

View File

@ -71,7 +71,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
const ShapeIndex& index, bool is_phi)
: id_(id), is_phi_(is_phi) {
// The defining position is always the first element in the positions_ vector.
AddPosition(instruction, index);
positions_.push_back(HloPosition{instruction, index});
}
bool HloValue::operator==(const HloValue& other) const {
@ -130,18 +130,14 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
CHECK_LE(operand_number, 2);
return operand_number == 0 || index.empty();
case HloOpcode::kCall:
case HloOpcode::kTuple:
// These instructions always pass through their operands transparently.
return false;
case HloOpcode::kCall:
case HloOpcode::kWhile:
// Though the while instructions passes through its operands, we return
// true because in SSA form there may be a Phi at the parameter of the
// while which is considered a use of its incoming value because the Phi
// input values are not passed through into the body computation. Because
// this function is used in both SSA and non-SSA forms of the analysis
// conservatively return true.
// Although call and while instructions pass through their operands, they
// are considered uses.
return true;
default:
@ -151,103 +147,58 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
} // namespace
void HloValue::AddPosition(HloInstruction* instruction,
const ShapeIndex& index) {
HloPosition new_position{instruction, index};
void HloValue::SetPositionsAndComputeUses(
tensorflow::gtl::ArraySlice<HloPosition> positions) {
CHECK_EQ(positions_.size(), 1)
<< "SetPositionsAndComputeUses should only be called once.";
// The new position must not already exist in positions_.
// The positions must be unique and should not contain the defining position
// as this is added at construction time.
for (const HloPosition& position_a : positions) {
DCHECK_NE(position_a, defining_position());
for (const HloPosition& position_b : positions) {
if (&position_a != &position_b) {
DCHECK_NE(position_a, position_b);
}
}
}
positions_.insert(positions_.end(), positions.begin(), positions.end());
// Gather the computation roots at which this value appears.
tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
for (const HloPosition& position : positions_) {
DCHECK_NE(position, new_position);
}
positions_.push_back(std::move(new_position));
// Update uses.
for (HloInstruction* user : instruction->users()) {
for (int64 operand_number : user->OperandIndices(instruction)) {
if (MayUseOperandValue(operand_number, index, user)) {
HloUse new_use{user, operand_number, index};
// The new use must not already exist in uses_.
for (const HloUse& use : uses_) {
DCHECK_NE(use, new_use);
}
uses_.push_back(std::move(new_use));
}
if (position.instruction ==
position.instruction->parent()->root_instruction()) {
root_positions.insert(position.instruction);
}
}
// Update liveout status of this HloValue.
const HloModule& module = *instruction->parent()->parent();
if (instruction == module.entry_computation()->root_instruction()) {
live_out_of_module_ = true;
}
if (instruction == instruction->parent()->root_instruction()) {
live_out_of_computation_ = true;
}
}
void HloValue::RemovePosition(HloInstruction* instruction,
const ShapeIndex& index) {
// The defining position cannot be removed.
CHECK(!(instruction == defining_instruction() && index == defining_index()));
int64 size_before = positions_.size();
positions_.erase(
std::remove_if(positions_.begin(), positions_.end(),
[instruction, &index](const HloPosition& position) {
return position.instruction == instruction &&
position.index == index;
}),
positions_.end());
// Only a single position should have been removed.
CHECK_EQ(positions_.size(), size_before - 1);
// Update uses which referred to this position.
uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
[instruction, &index](const HloUse& use) {
return use.instruction->operand(
use.operand_number) == instruction &&
use.operand_index == index;
}),
uses_.end());
// Returns whether this value is contained in the given instruction's output.
auto is_contained_in = [this](const HloInstruction* instruction) {
for (const HloPosition& position : positions()) {
if (position.instruction == instruction) {
return true;
}
}
return false;
};
const HloModule& module = *instruction->parent()->parent();
if (instruction == module.entry_computation()->root_instruction()) {
// Value has been removed from a position in the entry root instruction.
live_out_of_module_ =
is_contained_in(module.entry_computation()->root_instruction());
}
if (instruction == defining_instruction()->parent()->root_instruction()) {
// Value has been removed from the root of the computation the value has
// been defined in.
live_out_of_computation_ =
is_contained_in(defining_instruction()->parent()->root_instruction());
}
}
void HloValue::RecomputeUses() {
uses_.clear();
for (const HloPosition& position : positions()) {
// Build vector of HloUses for the value.
for (const HloPosition& position : positions_) {
for (HloInstruction* user : position.instruction->users()) {
for (int64 operand_number : user->OperandIndices(position.instruction)) {
if (MayUseOperandValue(operand_number, position.index, user)) {
uses_.push_back(HloUse{user, operand_number, position.index});
// Root instructions of computations are considered to be uses whether
// or not the root instruction itself actually uses the value.
if (MayUseOperandValue(operand_number, position.index, user) ||
ContainsKey(root_positions, user)) {
HloUse new_use{user, operand_number, position.index};
// The new use must not already exist in uses_.
for (const HloUse& use : uses_) {
DCHECK_NE(use, new_use);
}
uses_.push_back(std::move(new_use));
}
}
}
// Update liveout status of this HloValue.
const HloModule& module = *position.instruction->parent()->parent();
if (position.instruction ==
module.entry_computation()->root_instruction()) {
live_out_of_module_ = true;
}
}
}

View File

@ -121,6 +121,12 @@ class HloValue {
HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);
// Sets the positions in the module at which the HloValue appears. Updates
// uses. Should be called once and only once. The defining position should not
// be included in 'positions' as this is set at construction time.
void SetPositionsAndComputeUses(
tensorflow::gtl::ArraySlice<HloPosition> positions);
// Return a unique identifier for this HloValue. This value is used for stable
// sorting and iteration
Id id() const { return id_; }
@ -143,28 +149,15 @@ class HloValue {
// Return the shape of this HloValue.
const Shape& shape() const { return defining_position().shape(); }
// Add or remove a position at which the HloValue appears. The definition
// position can not be removed. The uses of the HloValue are updated.
void AddPosition(HloInstruction* instruction, const ShapeIndex& index);
void RemovePosition(HloInstruction* instruction, const ShapeIndex& index);
// Remove all positions except the defining position. Updates uses.
void ClearPositions();
// Return all positions of the HloValue in the module.
const std::vector<HloPosition>& positions() const { return positions_; }
// Return all uses of the HloValue.
const std::vector<HloUse>& uses() const { return uses_; }
void RecomputeUses();
// Get whether this HloValue is live out of the module.
bool live_out_of_module() const { return live_out_of_module_; }
// Get whether this HloValue is live out of the computation it is defined in.
bool live_out_of_computation() const { return live_out_of_computation_; }
bool operator==(const HloValue& other) const;
bool operator!=(const HloValue& other) const;

View File

@ -92,6 +92,7 @@ namespace xla {
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:

View File

@ -83,7 +83,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
if (std::find(parameter_instructions.begin(),
parameter_instructions.end(),
&hlo) != parameter_instructions.end()) {
array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{}));
array->MarkInvariantOverWholeProgram(context_);
}
}
}

View File

@ -256,10 +256,10 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
llvm::Instruction* instruction) const {
CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
llvm::isa<llvm::StoreInst>(instruction));
CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
<< "Trying to create a store to an invariant IRArray.";
for (const auto& kind_md_pair : metadata_) {
CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load ||
llvm::isa<llvm::LoadInst>(instruction));
instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
}
}

View File

@ -229,9 +229,33 @@ class IrArray {
AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
}
void AddInvariantLoad(llvm::MDNode* invariant_load) {
CHECK_NE(invariant_load, nullptr);
AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load);
// Promises LLVM that the data pointed to by this IrArray never changes after
// it's first loaded.
//
// The temporal scope of this promise is the "whole program" from LLVM's point
// of view, but how this translates to HLOs differs between backends.
//
// In the single-threaded CPU backend, we emit one function that
// runs all the HLOs in sequence, so the whole program is the whole HLO
// module.
//
// In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
// in the entry computation). From LLVM's perspective, launching a new kernel
// is like launching a new program, and so the whole program is one top-level
// HLO. Since the scope of the promise is smaller than in the CPU backend, we
// can mark more things as invariant in the GPU backend.
//
// Marking loads as invariant is particularly helpful on GPUs because
// invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
// __ldg intrinsic). These loads use a special cache, and can be
// significantly faster than regular loads.
void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
if (is_invariant_) {
return;
}
is_invariant_ = true;
AddMetadata(llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(*context, {}));
}
const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
@ -261,6 +285,8 @@ class IrArray {
// loads/stores for this array. The keys are the metadata kinds and the
// values are the metadata nodes.
std::map<int, llvm::MDNode*> metadata_;
bool is_invariant_ = false;
};
} // namespace llvm_ir
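For context on how the promise reaches LLVM: a minimal sketch, assuming only the standard LLVM C++ API, of attaching the invariant_load metadata the way MarkInvariantOverWholeProgram does through AddMetadata:
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
// Marks a single load as invariant: a promise that the loaded memory never
// changes after it is first loaded. invariant_load takes an empty MDNode.
void MarkLoadInvariant(llvm::LoadInst* load) {
  llvm::LLVMContext& context = load->getContext();
  load->setMetadata(llvm::LLVMContext::MD_invariant_load,
                    llvm::MDNode::get(context, /*MDs=*/{}));
}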

View File

@ -63,6 +63,14 @@ void ShapedBuffer::clear() {
}
}
void ShapedBuffer::AddBufferAtIndex(
const perftools::gputools::DeviceMemoryBase& buffer,
const ShapeIndex& shape_index) {
*mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) =
buffers().size();
mutable_buffers()->push_back(buffer);
}
const se::DeviceMemoryBase& ShapedBuffer::buffer(
const ShapeIndex& index) const {
return buffers_[shape_index_to_buffer_entry_.element(index)];
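AddBufferAtIndex keeps the existing indirection: the shape tree stores positions into the flat buffer vector. A toy sketch with a hypothetical int shape index standing in for a ShapeIndex:
#include <cstddef>
#include <map>
#include <vector>
struct ToyBuffer {
  void* ptr = nullptr;
  size_t size = 0;
};
struct ToyShapedBuffer {
  std::map<int, size_t> index_to_entry;  // shape index -> slot in buffers
  std::vector<ToyBuffer> buffers;
  // Appends the buffer and records its slot under the given shape index,
  // mirroring the two statements in AddBufferAtIndex above.
  void AddBufferAtIndex(const ToyBuffer& buffer, int shape_index) {
    index_to_entry[shape_index] = buffers.size();
    buffers.push_back(buffer);
  }
  const ToyBuffer& buffer(int shape_index) const {
    return buffers[index_to_entry.at(shape_index)];
  }
};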

View File

@ -75,6 +75,10 @@ class ShapedBuffer {
// Set all device memory pointers in the object to null.
void clear();
// Adds a new buffer at the given shape index.
void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer,
const ShapeIndex& shape_index);
protected:
// The shape of the device buffer with layout.
const Shape shape_;

View File

@ -28,12 +28,9 @@ limitations under the License.
namespace se = ::perftools::gputools;
namespace xla {
/* static */ tensorflow::mutex*
TransferManager::platform_transfer_manager_mutex() {
static tensorflow::mutex* m = new tensorflow::mutex;
return m;
}
/* static */ tensorflow::mutex
TransferManager::platform_transfer_manager_mutex_(
tensorflow::LINKER_INITIALIZED);
/* static */ std::map<perftools::gputools::Platform::Id,
TransferManager::State>*
@ -47,7 +44,7 @@ TransferManager::GetPlatformTransferManagers() {
se::Platform::Id platform_id,
TransferManagerCreationFunction creation_function) {
tensorflow::mutex_lock lock(
*TransferManager::platform_transfer_manager_mutex());
TransferManager::platform_transfer_manager_mutex_);
auto* managers = GetPlatformTransferManagers();
CHECK(managers->find(platform_id) == managers->end());
(*managers)[platform_id].creation_function = creation_function;
@ -56,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() {
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
const se::Platform* platform) {
tensorflow::mutex_lock lock(
*TransferManager::platform_transfer_manager_mutex());
TransferManager::platform_transfer_manager_mutex_);
auto* managers = GetPlatformTransferManagers();
auto it = managers->find(platform->id());

View File

@ -158,11 +158,8 @@ class TransferManager {
const perftools::gputools::Platform* platform);
private:
// Routine that returns the mutex that guards the
// platform-to-transfer manager map. Done as a routine to
// ensure correct initialization ordering, since RegisterTransferManager
// can be called during program initialization time.
static tensorflow::mutex* platform_transfer_manager_mutex();
// The mutex that guards the platform-to-transfer manager map.
static tensorflow::mutex platform_transfer_manager_mutex_;
// State kept for each kind of TransferManager. Registration functions
// set up creation_function, and then we use that to lazily create

View File

@ -592,10 +592,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return sizeof(uint32);
case U64:
return sizeof(uint64);
case F16:
return sizeof(float) / 2;
case BF16:
return sizeof(float) / 2;
case F16:
return sizeof(float) / 2;
case F32:
return sizeof(float);
case F64:

View File

@ -69,7 +69,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_headers_lib",
],
)

View File

@ -248,5 +248,6 @@ def generate_backend_test_macros(backends=[]):
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
])

View File

@ -21,12 +21,13 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Mapping from test name (e.g. MyTest.MyTestCase) to platforms on which it is
// disabled.
// disabled - a sequence of regexps.
using ManifestT = std::unordered_map<string, std::vector<string>>;
ManifestT ReadManifest() {
@ -66,9 +67,6 @@ ManifestT ReadManifest() {
string PrependDisabledIfIndicated(const string& test_case_name,
const string& test_name) {
// TODO(leary): this code reads the manifest for every test case instantiated
// in every file. Consider switching to a singleton or using a compile-time
// genrule instead.
ManifestT manifest = ReadManifest();
// First try full match: test_case_name.test_name
@ -83,11 +81,13 @@ string PrependDisabledIfIndicated(const string& test_case_name,
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<string>& disabled_platforms = it->second;
string platform_string = XLA_PLATFORM;
if (std::find(disabled_platforms.begin(), disabled_platforms.end(),
platform_string) != disabled_platforms.end()) {
return "DISABLED_" + test_name;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
return "DISABLED_" + test_name;
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
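A condensed sketch of the matching logic above, assuming RE2 and a hypothetical IsDisabledOn helper: each manifest entry is now a regexp that must fully match the platform string, so exact platform names keep working while patterns become possible.
#include <string>
#include <vector>
#include "tensorflow/core/platform/regexp.h"
// Returns true if `platform` fully matches any disabling regexp.
bool IsDisabledOn(const std::string& platform,
                  const std::vector<std::string>& disabled_platform_regexps) {
  for (const std::string& re : disabled_platform_regexps) {
    if (RE2::FullMatch(platform, re)) return true;
  }
  return false;
}
// e.g. IsDisabledOn("CPU-JIT", {"CPU.*"}) -> true;
//      IsDisabledOn("GPU", {"CPU"}) -> false.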

View File

@ -66,8 +66,10 @@ limitations under the License.
namespace xla {
// Reads a disabled manifest file (and retains it as a singleton) to resolve
// whether test cases should be disabled on a particular platform.
// Reads a disabled manifest file to resolve whether test cases should be
// disabled on a particular platform. For a test that should be disabled,
// returns DISABLED_ prepended to its name; otherwise returns the test name
// unmodified.
string PrependDisabledIfIndicated(const string& test_case_name,
const string& test_name);

View File

@ -14,8 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
namespace xla {
@ -46,6 +47,44 @@ void PopulateWithRandomIntegralData(Literal* literal) {
}));
}
bool LooksLikeSum(const HloInstruction& instruction) {
return instruction.opcode() == HloOpcode::kAdd &&
instruction.operand(0)->opcode() == HloOpcode::kParameter &&
instruction.operand(1)->opcode() == HloOpcode::kParameter &&
instruction.operand(0) != instruction.operand(1);
}
// Given an instruction and operand number, replace the given operand with
// a Literal Constant Zero. Handle the case of a fusion instruction by
// replacing the fusion's parent's parameter with a Literal Constant Zero,
// unless the fusion's parent is itself a fusion.
Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction,
const int64 operand_number) {
CHECK_LT(operand_number, instruction->operand_count());
if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) {
return Status::OK();
}
HloComputation* const computation = instruction->parent();
std::unique_ptr<HloInstruction> zero = HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(instruction->shape().element_type())));
if (computation->IsFusionComputation()) {
HloInstruction* const fusion_instruction = computation->FusionInstruction();
if (fusion_instruction->IsFused()) {
return Unimplemented(
"Unable to replace fused parameter of fusion instruction");
}
TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith(
instruction->operand(operand_number)->parameter_number(),
fusion_instruction->parent()->AddInstruction(std::move(zero))));
} else {
TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(
operand_number, computation->AddInstruction(std::move(zero))));
}
return Status::OK();
}
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
@ -117,4 +156,32 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
return std::move(arguments);
}
Status ReplaceInitsWithConstants(HloModule* const module) {
for (HloComputation* const computation : module->computations()) {
for (HloInstruction* const instruction : computation->instructions()) {
const HloOpcode opcode = instruction->opcode();
if ((opcode == HloOpcode::kReduce ||
opcode == HloOpcode::kReduceWindow) &&
LooksLikeSum(*instruction->to_apply()->root_instruction())) {
TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1));
} else if (opcode == HloOpcode::kSelectAndScatter &&
LooksLikeSum(*instruction->scatter()->root_instruction())) {
TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2));
}
}
}
return Status::OK();
}
Status VerifyHloModule(const perftools::gputools::Platform& platform,
HloModule* const module) {
return HloVerifier(
std::bind(
&TransferManager::GetByteSizeRequirement,
TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(),
std::placeholders::_1))
.Run(module)
.status();
}
} // namespace xla

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
@ -62,6 +63,16 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
const HloModule& module);
// Reductions using Adds, ReduceWindow, and SelectAndScatter require their
// init_value to be replaced with the constant 0.0f when testing, otherwise we
// may generate a bad init_value when looking at the op in isolation.
Status ReplaceInitsWithConstants(HloModule* const module);
// Check that a given module satisfies various constraints before trying to
// execute it.
Status VerifyHloModule(const perftools::gputools::Platform& platform,
HloModule* const module);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_

View File

@ -776,11 +776,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, *fusion_kind, operands, *fusion_computation));
break;
}
case HloOpcode::kInfeed: {
optional<string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateInfeed(shape, config ? *config : ""));
break;
}
case HloOpcode::kOutfeed: {
optional<string> config;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
shape, operands[0], config ? *config : ""));
break;
}
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kReducePrecision:
case HloOpcode::kRng:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
@ -805,7 +826,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
instruction->set_metadata(*metadata);
}
return AddInstruction(name, instruction);
}
} // NOLINT(readability/fn_size)
// ::= '{' (single_sharding | tuple_sharding) '}'
//

View File

@ -560,6 +560,20 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
}
)"
},
// infeed/outfeed
{
"InfeedOutfeed",
R"(HloModule outfeed_module:
ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
%infeed = (u32[3]{0}, pred[]) infeed()
%outfeed = () outfeed((u32[3]{0}, pred[]) %infeed)
ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed()
%outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1)
}
)"
}
});
@ -866,7 +880,7 @@ TEST_F(HloParserTest, CommaBetweenSubAttributes) {
const string original = R"(HloModule test_comma_module:
ENTRY %test_comma.v4 () -> f32[] {
ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="const"}
ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
}
)";

View File

@ -17,3 +17,5 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0):
protoc="@protobuf_archive//:protoc",
testonly=testonly,
visibility=visibility,)
ORC_JIT_MEMORY_MAPPER_TARGETS = []

View File

@ -19,6 +19,7 @@ py_library(
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@ -32,7 +33,6 @@ py_library(
"//tensorflow/python:random_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:util",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
],
)
@ -99,6 +99,25 @@ cuda_py_test(
],
)
cuda_py_test(
name = "layers_dense_variational_test",
size = "small",
srcs = ["python/kernel_tests/layers_dense_variational_test.py"],
additional_deps = [
":bayesflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
],
)
cuda_py_test(
name = "monte_carlo_test",
size = "small",
@ -160,6 +179,27 @@ cuda_py_test(
],
)
cuda_py_test(
name = "sgld_optimizer_test",
size = "small",
srcs = ["python/kernel_tests/sgld_optimizer_test.py"],
additional_deps = [
":bayesflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_seed",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -25,16 +25,28 @@ from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence
from tensorflow.contrib.bayesflow.python.ops import custom_grad
from tensorflow.contrib.bayesflow.python.ops import halton_sequence
from tensorflow.contrib.bayesflow.python.ops import hmc
from tensorflow.contrib.bayesflow.python.ops import layers
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
from tensorflow.contrib.bayesflow.python.ops import optimizers
# pylint: enable=unused-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy',
'metropolis_hastings', 'monte_carlo', 'halton_sequence',
'hmc', 'special_math', 'stochastic_variables',
'variational_inference']
_allowed_symbols = [
'csiszar_divergence',
'custom_grad',
'entropy',
'halton_sequence',
'hmc',
'layers',
'metropolis_hastings',
'monte_carlo',
'optimizers',
'special_math',
'stochastic_variables',
'variational_inference',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,304 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dense Bayesian layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
class Counter(object):
"""Helper class to manage incrementing a counting `int`."""
def __init__(self):
self._value = -1
@property
def value(self):
return self._value
def __call__(self):
self._value += 1
return self._value
class MockDistribution(normal_lib.Normal):
"""Monitors DenseVariational calls to the underlying distribution."""
def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
self.result_sample = result_sample
self.result_log_prob = result_log_prob
self.result_loc = loc
self.result_scale = scale
self.called_log_prob = Counter()
self.called_sample = Counter()
self.called_loc = Counter()
self.called_scale = Counter()
def log_prob(self, *args, **kwargs):
self.called_log_prob()
return self.result_log_prob
def sample(self, *args, **kwargs):
self.called_sample()
return self.result_sample
@property
def loc(self):
self.called_loc()
return self.result_loc
@property
def scale(self):
self.called_scale()
return self.result_scale
class MockKLDivergence(object):
"""Monitors DenseVariational calls to the divergence implementation."""
def __init__(self, result):
self.result = result
self.args = []
self.called = Counter()
def __call__(self, *args, **kwargs):
self.called()
self.args.append(args)
return self.result
class DenseVariationalLocalReparametrization(test.TestCase):
def testKLPenaltyKernel(self):
with self.test_session():
dense_vi = prob_layers_lib.DenseVariational(units=2)
inputs = random_ops.random_uniform([2, 3], seed=1)
# No keys.
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 0)
self.assertListEqual(dense_vi.losses, loss_keys)
_ = dense_vi(inputs)
# Yes keys.
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
self.assertListEqual(dense_vi.losses, loss_keys)
def testKLPenaltyBoth(self):
def _make_normal(dtype, *args): # pylint: disable=unused-argument
return normal_lib.Normal(
loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.))
with self.test_session():
dense_vi = prob_layers_lib.DenseVariational(
units=2,
bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(),
bias_prior_fn=_make_normal)
inputs = random_ops.random_uniform([2, 3], seed=1)
# No keys.
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 0)
self.assertListEqual(dense_vi.losses, loss_keys)
_ = dense_vi(inputs)
# Yes keys.
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 2)
self.assertListEqual(dense_vi.losses, loss_keys)
def testVariationalNonLocal(self):
batch_size, in_size, out_size = 2, 3, 4
with self.test_session() as sess:
seed = Counter()
inputs = random_ops.random_uniform([batch_size, in_size], seed=seed())
kernel_size = [in_size, out_size]
kernel_posterior = MockDistribution(
result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
kernel_prior = MockDistribution(
result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
kernel_divergence = MockKLDivergence(
result=random_ops.random_uniform(kernel_size, seed=seed()))
bias_size = [out_size]
bias_posterior = MockDistribution(
result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
result_sample=random_ops.random_uniform(bias_size, seed=seed()))
bias_prior = MockDistribution(
result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
result_sample=random_ops.random_uniform(bias_size, seed=seed()))
bias_divergence = MockKLDivergence(
result=random_ops.random_uniform(bias_size, seed=seed()))
expected_outputs = (
math_ops.matmul(inputs, kernel_posterior.result_sample) +
bias_posterior.result_sample)
dense_vi = prob_layers_lib.DenseVariational(
units=2,
kernel_use_local_reparameterization=False,
kernel_posterior_fn=lambda *args: kernel_posterior,
kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
kernel_prior_fn=lambda *args: kernel_prior,
kernel_divergence_fn=kernel_divergence,
bias_posterior_fn=lambda *args: bias_posterior,
bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
bias_prior_fn=lambda *args: bias_prior,
bias_divergence_fn=bias_divergence)
outputs = dense_vi(inputs)
kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
[
expected_outputs_, actual_outputs_,
expected_kernel_, actual_kernel_,
expected_kernel_divergence_, actual_kernel_divergence_,
expected_bias_, actual_bias_,
expected_bias_divergence_, actual_bias_divergence_,
] = sess.run([
expected_outputs, outputs,
kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor,
kernel_divergence.result, kl_penalty[0],
bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
bias_divergence.result, kl_penalty[1],
])
self.assertAllClose(
expected_kernel_, actual_kernel_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_bias_, actual_bias_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_outputs_, actual_outputs_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_kernel_divergence_, actual_kernel_divergence_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_bias_divergence_, actual_bias_divergence_,
rtol=1e-6, atol=0.)
self.assertAllEqual(
[[kernel_posterior, kernel_prior, kernel_posterior.result_sample]],
kernel_divergence.args)
self.assertAllEqual(
[[bias_posterior, bias_prior, bias_posterior.result_sample]],
bias_divergence.args)
def testVariationalLocal(self):
batch_size, in_size, out_size = 2, 3, 4
with self.test_session() as sess:
seed = Counter()
inputs = random_ops.random_uniform([batch_size, in_size], seed=seed())
kernel_size = [in_size, out_size]
kernel_posterior = MockDistribution(
loc=random_ops.random_uniform(kernel_size, seed=seed()),
scale=random_ops.random_uniform(kernel_size, seed=seed()),
result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
kernel_prior = MockDistribution(
result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
kernel_divergence = MockKLDivergence(
result=random_ops.random_uniform(kernel_size, seed=seed()))
bias_size = [out_size]
bias_posterior = MockDistribution(
result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
result_sample=random_ops.random_uniform(bias_size, seed=seed()))
bias_prior = MockDistribution(
result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
result_sample=random_ops.random_uniform(bias_size, seed=seed()))
bias_divergence = MockKLDivergence(
result=random_ops.random_uniform(bias_size, seed=seed()))
expected_kernel_posterior_affine = normal_lib.Normal(
loc=math_ops.matmul(inputs, kernel_posterior.result_loc),
scale=math_ops.matmul(
inputs**2., kernel_posterior.result_scale**2)**0.5)
expected_kernel_posterior_affine_tensor = (
expected_kernel_posterior_affine.sample(seed=42))
expected_outputs = (expected_kernel_posterior_affine_tensor +
bias_posterior.result_sample)
dense_vi = prob_layers_lib.DenseVariational(
units=2,
kernel_use_local_reparameterization=True,
kernel_posterior_fn=lambda *args: kernel_posterior,
kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
kernel_prior_fn=lambda *args: kernel_prior,
kernel_divergence_fn=kernel_divergence,
bias_posterior_fn=lambda *args: bias_posterior,
bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
bias_prior_fn=lambda *args: bias_prior,
bias_divergence_fn=bias_divergence)
outputs = dense_vi(inputs)
kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
[
expected_outputs_, actual_outputs_,
expected_kernel_divergence_, actual_kernel_divergence_,
expected_bias_, actual_bias_,
expected_bias_divergence_, actual_bias_divergence_,
] = sess.run([
expected_outputs, outputs,
kernel_divergence.result, kl_penalty[0],
bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
bias_divergence.result, kl_penalty[1],
])
self.assertAllClose(
expected_bias_, actual_bias_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_outputs_, actual_outputs_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_kernel_divergence_, actual_kernel_divergence_,
rtol=1e-6, atol=0.)
self.assertAllClose(
expected_bias_divergence_, actual_bias_divergence_,
rtol=1e-6, atol=0.)
self.assertAllEqual(
[[kernel_posterior, kernel_prior, None]],
kernel_divergence.args)
self.assertAllEqual(
[[bias_posterior, bias_prior, bias_posterior.result_sample]],
bias_divergence.args)
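The expected values in testVariationalLocal above follow the local reparameterization trick: rather than sampling the kernel W and then computing the matmul, the matmul's output is sampled directly from its induced Gaussian. In LaTeX, the identity the test encodes, assuming a fully factorized normal posterior over W:
W_{ij} \sim \mathcal{N}(\mu_{ij}, \sigma_{ij}^2)
\;\Longrightarrow\;
(xW)_j \sim \mathcal{N}\!\Big(\sum_i x_i \mu_{ij},\ \sum_i x_i^2 \sigma_{ij}^2\Big)
This is exactly the Normal(loc=matmul(inputs, loc), scale=sqrt(matmul(inputs**2, scale**2))) constructed for expected_kernel_posterior_affine.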
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,209 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for GradientDescent."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
class SGLDOptimizerTest(test.TestCase):
def testBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
var0 = variables.Variable([1.1, 2.1], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
decay_rate = 0.53
sgd_op = SGLDOptimizer(
3.0, preconditioner_decay_rate=decay_rate).apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
# Run 1 step of sgd
sgd_op.run()
# Validate updated params
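# With the `rms` slot initialized to ones, one momentum update gives
# rms = decay_rate + (1 - decay_rate) * grad**2; during burnin no noise
# is added, so the step is lr * 0.5 * grad / sqrt(rms + diagonal_bias).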
grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
(1 - decay_rate) * 0.1**2 + 1e-8))
self.assertAllCloseAccordingToType(
[1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
grads_scaled = (0.5 * 0.01 / math.sqrt(
decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())
def testBasicMultiInstance(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
var0 = variables.Variable([1.1, 2.1], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
vara = variables.Variable([1.1, 2.1], dtype=dtype)
varb = variables.Variable([3.0, 4.0], dtype=dtype)
gradsa = constant_op.constant([0.1, 0.1], dtype=dtype)
gradsb = constant_op.constant([0.01, 0.01], dtype=dtype)
decay_rate = 0.5
sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate)
sgd_op = sgd_optimizer.apply_gradients(
zip([grads0, grads1], [var0, var1]))
sgd_optimizer2 = SGLDOptimizer(
3.0, preconditioner_decay_rate=decay_rate)
sgd_op2 = sgd_optimizer2.apply_gradients(
zip([gradsa, gradsb], [vara, varb]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval())
# Run 1 step of sgd
sgd_op.run()
sgd_op2.run()
# Validate updated params
grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
(1 - decay_rate) * 0.1**2 + 1e-8))
self.assertAllCloseAccordingToType(
[1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
self.assertAllCloseAccordingToType(
[1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval())
grads_scaled = (0.5 * 0.01 / math.sqrt(
decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval())
self.assertNotEqual(sgd_optimizer.variable_scope,
sgd_optimizer2.variable_scope)
self.assertNotEqual(sgd_optimizer.variable_scope.name,
sgd_optimizer2.variable_scope.name)
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
var0 = variables.Variable([1.1, 2.1], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
lrate = constant_op.constant(3.0)
decay_rate = 0.5
sgd_op = SGLDOptimizer(
lrate, preconditioner_decay_rate=constant_op.constant(
decay_rate)).apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
# Run 1 step of sgd
sgd_op.run()
# Validate updated params
grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
(1 - decay_rate) * 0.1**2 + 1e-8))
self.assertAllCloseAccordingToType(
[1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
grads_scaled = (0.5 * 0.01 / math.sqrt(
decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())
def testGradWrtRef(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
opt = SGLDOptimizer(3.0)
values = [1.0, 3.0]
vars_ = [variables.Variable([v], dtype=dtype) for v in values]
grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
variables.global_variables_initializer().run()
for grad, _ in grads_and_vars:
self.assertAllCloseAccordingToType([1.0], grad.eval())
def testWithGlobalStep(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
global_step = variables.Variable(0, trainable=False)
var0 = variables.Variable([1.1, 2.1], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
decay_rate = 0.1
sgd_op = SGLDOptimizer(
3.0, preconditioner_decay_rate=decay_rate).apply_gradients(
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
# Run 1 step of sgd
sgd_op.run()
# Validate updated params and global_step
grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
(1 - decay_rate) * 0.1**2 + 1e-8))
self.assertAllCloseAccordingToType(
[1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
grads_scaled = (0.5 * 0.01 / math.sqrt(
decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())
self.assertAllCloseAccordingToType(1, global_step.eval())
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
var0 = variables.Variable([[1.1], [2.1]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
constant_op.constant([0]), constant_op.constant([2, 1]))
grads1 = ops.IndexedSlices(
constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
constant_op.constant([1]), constant_op.constant([2, 1]))
decay_rate = 0.9
sgd_op = SGLDOptimizer(
3.0, preconditioner_decay_rate=decay_rate).apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval())
self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
# Run 1 step of sgd
sgd_op.run()
# Validate updated params
grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
(1 - decay_rate) * 0.1**2 + 1e-8))
self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]],
var0.eval())
grads_scaled = (0.5 * 0.01 / math.sqrt(
decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
self.assertAllCloseAccordingToType(
[[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval())
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,37 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Probabilistic neural layers.
See ${python/contrib.bayesflow.layers}.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'DenseVariational',
'dense_variational',
'default_loc_scale_fn',
'default_mean_field_normal_fn',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,797 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense Bayesian layer using KL-divergence based variational inference.
@@DenseVariational
@@dense_variational
@@default_loc_scale_fn
@@default_mean_field_normal_fn
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
from tensorflow.python.ops.distributions import normal as normal_lib
__all__ = [
"DenseVariational",
"dense_variational",
"default_loc_scale_fn",
"default_mean_field_normal_fn",
]
def default_loc_scale_fn(
is_singular=False,
loc_initializer=init_ops.random_normal_initializer(stddev=0.1),
untransformed_scale_initializer=init_ops.random_normal_initializer(
mean=-3., stddev=0.1),
loc_regularizer=None,
untransformed_scale_regularizer=None,
loc_constraint=None,
untransformed_scale_constraint=None):
"""Makes closure which creates `loc`, `scale` params from `tf.get_variable`.
This function produces a closure which produces `loc`, `scale` using
`tf.get_variable`. The closure accepts the following arguments:
dtype: Type of parameter's event.
shape: Python `list`-like representing the parameter's event shape.
name: Python `str` name prepended to any created (or existing)
`tf.Variable`s.
trainable: Python `bool` indicating all created `tf.Variable`s should be
added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
add_variable_fn: `tf.get_variable`-like `callable` used to create (or
access existing) `tf.Variable`s.
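A minimal sketch of using the returned closure (the dtype, shape, and name
below are illustrative, not prescribed):
```python
loc_scale_fn = default_loc_scale_fn()
loc, scale = loc_scale_fn(
    tf.float32, [5, 3], "kernel", True, tf.get_variable)
# `loc` is a plain trainable variable; `scale` is the softplus of a second
# trainable variable, shifted by the dtype's machine epsilon so it is
# strictly positive.
```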
Args:
is_singular: Python `bool` indicating whether the returned `scale` is
`None`. Default: `False`.
loc_initializer: Initializer function for the `loc` parameters.
The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`.
untransformed_scale_initializer: Initializer function for the `scale`
parameters. Default value: `tf.random_normal_initializer(mean=-3.,
stddev=0.1)`. This implies the softplus transformed result has mean
approximately `0.05` and std. deviation approximately `0.005`.
loc_regularizer: Regularizer function for the `loc` parameters.
The default (`None`) is to use the `tf.get_variable` default.
untransformed_scale_regularizer: Regularizer function for the `scale`
parameters. The default (`None`) is to use the `tf.get_variable` default.
loc_constraint: An optional projection function to be applied to the
loc after being updated by an `Optimizer`. The function must take as input
the unprojected variable and must return the projected variable (which
must have the same shape). Constraints are not safe to use when doing
asynchronous distributed training.
The default (`None`) is to use the `tf.get_variable` default.
untransformed_scale_constraint: An optional projection function to be
applied to the `scale` parameters after being updated by an `Optimizer`
(e.g. used to implement norm constraints or value constraints). The
function must take as input the unprojected variable and must return the
projected variable (which must have the same shape). Constraints are not
safe to use when doing asynchronous distributed training. The default
(`None`) is to use the `tf.get_variable` default.
Returns:
default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale`
parameters from args: `dtype, shape, name, trainable, add_variable_fn`.
"""
def _fn(dtype, shape, name, trainable, add_variable_fn):
"""Creates `loc`, `scale` parameters."""
loc = add_variable_fn(
name=name + "_loc",
shape=shape,
initializer=loc_initializer,
regularizer=loc_regularizer,
constraint=loc_constraint,
dtype=dtype,
trainable=trainable)
if is_singular:
return loc, None
untransformed_scale = add_variable_fn(
name=name + "_untransformed_scale",
shape=shape,
initializer=untransformed_scale_initializer,
regularizer=untransformed_scale_regularizer,
constraint=untransformed_scale_constraint,
dtype=dtype,
trainable=trainable)
scale = (np.finfo(dtype.as_numpy_dtype).eps +
nn_ops.softplus(untransformed_scale))
return loc, scale
return _fn
def default_mean_field_normal_fn(
is_singular=False,
loc_initializer=None,
untransformed_scale_initializer=None,
loc_regularizer=None,
untransformed_scale_regularizer=None,
loc_constraint=None,
untransformed_scale_constraint=None):
"""Creates a function to build Normal distributions with trainable params.
This function produces a closure which produces `tf.distributions.Normal`
parameterized by a `loc` and `scale`, each created using `tf.get_variable`. The
produced closure accepts the following arguments:
name: Python `str` name prepended to any created (or existing)
`tf.Variable`s.
shape: Python `list`-like representing the parameter's event shape.
dtype: Type of parameter's event.
trainable: Python `bool` indicating all created `tf.Variable`s should be
added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
add_variable_fn: `tf.get_variable`-like `callable` used to create (or
access existing) `tf.Variable`s.
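For example, a sketch of building a posterior over a `[5, 3]` kernel
(argument values are illustrative only):
```python
make_posterior_fn = default_mean_field_normal_fn()
posterior = make_posterior_fn(
    tf.float32, [5, 3], "kernel_posterior", True, tf.get_variable)
# `posterior` is a `tf.distributions.Normal` with trainable `loc`, `scale`;
# with `is_singular=True` it would instead be a `Deterministic` at `loc`.
```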
Args:
is_singular: Python `bool`. If `True`, forces the special-case limit of
`scale -> 0`, i.e., a `Deterministic` distribution.
loc_initializer: Initializer function for the `loc` parameters.
If `None` (default), values are initialized using the default
initializer used by `tf.get_variable`.
untransformed_scale_initializer: Initializer function for the `scale`
parameters. If `None` (default), values are initialized using the default
initializer used by `tf.get_variable`.
loc_regularizer: Regularizer function for the `loc` parameters.
untransformed_scale_regularizer: Regularizer function for the `scale`
parameters.
loc_constraint: An optional projection function to be applied to the
loc after being updated by an `Optimizer`. The function must take as input
the unprojected variable and must return the projected variable (which
must have the same shape). Constraints are not safe to use when doing
asynchronous distributed training.
untransformed_scale_constraint: An optional projection function to be
applied to the `scale` parameters after being updated by an `Optimizer`
(e.g. used to implement norm constraints or value constraints). The
function must take as input the unprojected variable and must return the
projected variable (which must have the same shape). Constraints are not
safe to use when doing asynchronous distributed training.
Returns:
make_normal_fn: Python `callable` which creates a `tf.distributions.Normal`
from args: `dtype, shape, name, trainable, add_variable_fn`.
"""
loc_scale_fn_ = default_loc_scale_fn(
is_singular,
loc_initializer,
untransformed_scale_initializer,
loc_regularizer,
untransformed_scale_regularizer,
loc_constraint,
untransformed_scale_constraint)
def _fn(dtype, shape, name, trainable, add_variable_fn):
"""Creates a batch of `Deterministic` or `Normal` distributions."""
loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
if scale is None:
return deterministic_lib.Deterministic(loc=loc)
return normal_lib.Normal(loc=loc, scale=scale)
return _fn
class DenseVariational(layers_lib.Layer):
"""Densely-connected variational class.
This layer implements the Bayesian variational inference analogue to:
`outputs = activation(matmul(inputs, kernel) + bias)`
by assuming the `kernel` and/or the `bias` are random variables.
The layer implements a stochastic dense calculation by making a Monte Carlo
approximation of a [variational Bayesian method based on KL divergence](
https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
```none
-log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
= -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
<= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's
= E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
~= m**-1 sum{ -log p(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
+ KL[q(W|x), p(W)]
```
where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
is a random variate or outcome of `W`, `y` is the label, `x` is the evidence,
and `~=` denotes an approximation which becomes exact as `m->inf`. The above
bound is sometimes referred to as the negative Evidence Lower BOund or
negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
layer is appropriate to use when the final loss is a negative log-likelihood.
The Monte-Carlo sum portion is used for the feed-forward calculation of the
DNN. The KL divergence portion can be added to the final loss via:
`loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
random variables (which together comprise `W`).
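A minimal usage sketch (`features` and `labels` are assumed to be supplied
by the caller; any training loop is omitted):
```python
net = DenseVariational(units=64, activation=tf.nn.relu)(features)
logits = DenseVariational(units=10)(net)
neg_log_likelihood = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
# Minimizing the Monte-Carlo NLL plus the KL penalty minimizes the
# negative ELBO bound above.
loss = neg_log_likelihood + kl
```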
Args:
units: Integer or Long, dimensionality of the output space.
activation: Activation function (`callable`). Set it to None to maintain a
linear activation.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
kernel_use_local_reparameterization: Python `bool` indicating whether
`kernel` calculation should employ the Local Reparameterization Trick.
When `True`, `kernel_posterior_fn` must create an instance of
`tf.distributions.Normal`.
kernel_posterior_fn: Python `callable` which creates
`tf.distributions.Distribution` instance representing the surrogate
posterior of the `kernel` parameter. Default value:
`default_mean_field_normal_fn()`.
kernel_posterior_tensor_fn: Python `callable` which takes a
`tf.distributions.Distribution` instance and returns a representative
value. Default value: `lambda d: d.sample()`.
kernel_prior_fn: Python `callable` which creates `tf.distributions`
instance. See `default_mean_field_normal_fn` docstring for required
parameter signature.
Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
kernel_divergence_fn: Python `callable` which takes the surrogate posterior
distribution, prior distribution and random variate sample(s) from the
surrogate posterior and computes or approximates the KL divergence. The
distributions are `tf.distributions.Distribution`-like instances and the
sample is a `Tensor`.
bias_posterior_fn: Python `callable` which creates
`tf.distributions.Distribution` instance representing the surrogate
posterior of the `bias` parameter. Default value:
`default_mean_field_normal_fn(is_singular=True)` (which creates an
instance of `tf.distributions.Deterministic`).
bias_posterior_tensor_fn: Python `callable` which takes a
`tf.distributions.Distribution` instance and returns a representative
value. Default value: `lambda d: d.sample()`.
bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
See `default_mean_field_normal_fn` docstring for required parameter
signature. Default value: `None` (no prior, no variational inference).
bias_divergence_fn: Python `callable` which takes the surrogate posterior
distribution, prior distribution and random variate sample(s) from the
surrogate posterior and computes or approximates the KL divergence. The
distributions are `tf.distributions.Distribution`-like instances and the
sample is a `Tensor`.
name: Python `str`, the name of the layer. Layers with the same name will
share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
such cases.
reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
layer by the same name.
Properties:
units: Python integer, dimensionality of the output space.
activation: Activation function (`callable`).
activity_regularizer: Regularizer function for the output.
kernel_use_local_reparameterization: Python `bool` indicating whether
`kernel` calculation should employ the Local Reparameterization Trick.
kernel: `VariationalKernelParameter` instance containing all `kernel`
related properties and `callable`s.
bias: `VariationalParameter` instance containing all `bias`
related properties and `callable`s.
"""
def __init__(
self,
units,
activation=None,
activity_regularizer=None,
trainable=True,
kernel_use_local_reparameterization=True,
kernel_posterior_fn=default_mean_field_normal_fn(),
kernel_posterior_tensor_fn=lambda d: d.sample(),
kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
bias_posterior_tensor_fn=lambda d: d.sample(),
bias_prior_fn=None,
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
super(DenseVariational, self).__init__(
trainable=trainable,
name=name,
activity_regularizer=activity_regularizer,
**kwargs)
self._units = units
self._activation = activation
self._input_spec = layers_lib.InputSpec(min_ndim=2)
self._kernel_use_local_reparameterization = (
kernel_use_local_reparameterization)
self._kernel = VariationalKernelParameter(
kernel_posterior_fn,
kernel_posterior_tensor_fn,
kernel_prior_fn,
kernel_divergence_fn)
self._bias = VariationalParameter(
bias_posterior_fn,
bias_posterior_tensor_fn,
bias_prior_fn,
bias_divergence_fn)
@property
def units(self):
return self._units
@property
def activation(self):
return self._activation
@property
def input_spec(self):
return self._input_spec
@input_spec.setter
def input_spec(self, value):
self._input_spec = value
@property
def kernel_use_local_reparameterization(self):
return self._kernel_use_local_reparameterization
@property
def kernel(self):
return self._kernel
@property
def bias(self):
return self._bias
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
in_size = input_shape.with_rank_at_least(2)[-1].value
if in_size is None:
raise ValueError("The last dimension of the inputs to `Dense` "
"should be defined. Found `None`.")
self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size})
dtype = dtypes.as_dtype(self.dtype)
# Must have a posterior kernel.
self.kernel.posterior = self.kernel.posterior_fn(
dtype, [in_size, self.units], "kernel_posterior",
self.trainable, self.add_variable)
if self.kernel.prior_fn is None:
self.kernel.prior = None
else:
self.kernel.prior = self.kernel.prior_fn(
dtype, [in_size, self.units], "kernel_prior",
self.trainable, self.add_variable)
self._built_kernel_divergence = False
if self.bias.posterior_fn is None:
self.bias.posterior = None
else:
self.bias.posterior = self.bias.posterior_fn(
dtype, [self.units], "bias_posterior",
self.trainable, self.add_variable)
if self.bias.prior_fn is None:
self.bias.prior = None
else:
self.bias.prior = self.bias.prior_fn(
dtype, [self.units], "bias_prior",
self.trainable, self.add_variable)
self._built_bias_divergence = False
self.built = True
def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
outputs = self._apply_variational_kernel(inputs)
outputs = self._apply_variational_bias(outputs)
if self.activation is not None:
outputs = self.activation(outputs) # pylint: disable=not-callable
if not self._built_kernel_divergence:
self._apply_divergence(self.kernel, name="divergence_kernel")
self._built_kernel_divergence = True
if not self._built_bias_divergence:
self._apply_divergence(self.bias, name="divergence_bias")
self._built_bias_divergence = True
return outputs
def _apply_variational_kernel(self, inputs):
if not self.kernel_use_local_reparameterization:
self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn(
self.kernel.posterior)
self.kernel.posterior_affine = None
self.kernel.posterior_affine_tensor = None
return self._matmul(inputs, self.kernel.posterior_tensor)
if not isinstance(self.kernel.posterior, normal_lib.Normal):
raise TypeError("`kernel_use_local_reparameterization=True` requires "
"`kernel_posterior_fn` produce an instance of "
"`tf.distributions.Normal` (saw: \"{}\").".format(
type(self.kernel.posterior).__name__))
self.kernel.posterior_affine = normal_lib.Normal(
loc=self._matmul(inputs, self.kernel.posterior.loc),
scale=standard_ops.sqrt(self._matmul(
standard_ops.square(inputs),
standard_ops.square(self.kernel.posterior.scale))))
self.kernel.posterior_affine_tensor = (
self.kernel.posterior_tensor_fn(self.kernel.posterior_affine))
self.kernel.posterior_tensor = None
return self.kernel.posterior_affine_tensor
def _apply_variational_bias(self, inputs):
if self.bias.posterior is None:
self.bias.posterior_tensor = None
return inputs
self.bias.posterior_tensor = self.bias.posterior_tensor_fn(
self.bias.posterior)
return nn.bias_add(inputs, self.bias.posterior_tensor)
def _apply_divergence(self, param, name):
if (param.divergence_fn is None or
param.posterior is None or
param.prior is None):
param.divergence = None
return
param.divergence = standard_ops.identity(
param.divergence_fn(
param.posterior, param.prior, param.posterior_tensor),
name=name)
self.add_loss(param.divergence)
def _matmul(self, inputs, kernel):
if inputs.shape.ndims <= 2:
return standard_ops.matmul(inputs, kernel)
# To handle broadcasting, we must use `tensordot`.
return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]])
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2)
if input_shape[-1].value is None:
raise ValueError(
"The innermost dimension of input_shape must be defined, "
"but saw: {}".format(input_shape))
return input_shape[:-1].concatenate(self.units)
def dense_variational(
inputs,
units,
activation=None,
activity_regularizer=None,
trainable=True,
kernel_use_local_reparameterization=True,
kernel_posterior_fn=default_mean_field_normal_fn(),
kernel_posterior_tensor_fn=lambda d: d.sample(),
kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
bias_posterior_tensor_fn=lambda d: d.sample(),
bias_prior_fn=None,
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
reuse=None):
"""Densely-connected variational layer.
This layer implements the Bayesian variational inference analogue to:
`outputs = activation(matmul(inputs, kernel) + bias)`
by assuming the `kernel` and/or the `bias` are random variables.
The layer implements a stochastic dense calculation by making a Monte Carlo
approximation of a [variational Bayesian method based on KL divergence](
https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
```none
-log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
= -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
<= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's
= E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
~= m**-1 sum{ -log p(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
+ KL[q(W|x), p(W)]
```
where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
is a random variate or outcome of `W`, `y` is the label, `x` is the evidence,
and `~=` denotes an approximation which becomes exact as `m->inf`. The above
bound is sometimes referred to as the negative Evidence Lower BOund or
negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
layer is appropriate to use when the final loss is a negative log-likelihood.
The Monte-Carlo sum portion is used for the feed-forward calculation of the
DNN. The KL divergence portion can be added to the final loss via:
`loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
random variables (which together comprise `W`).
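The functional interface mirrors the `DenseVariational` class; a one-layer
sketch (`inputs` and the downstream loss assembly are assumed to exist):
```python
outputs = dense_variational(inputs, units=10, activation=tf.nn.relu)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
```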
Args:
inputs: Tensor input.
units: Integer or Long, dimensionality of the output space.
activation: Activation function (`callable`). Set it to None to maintain a
linear activation.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
kernel_use_local_reparameterization: Python `bool` indicating whether
`kernel` calculation should employ the Local Reparameterization Trick.
When `True`, `kernel_posterior_fn` must create an instance of
`tf.distributions.Normal`.
kernel_posterior_fn: Python `callable` which creates
`tf.distributions.Distribution` instance representing the surrogate
posterior of the `kernel` parameter. Default value:
`default_mean_field_normal_fn()`.
kernel_posterior_tensor_fn: Python `callable` which takes a
`tf.distributions.Distribution` instance and returns a representative
value. Default value: `lambda d: d.sample()`.
kernel_prior_fn: Python `callable` which creates `tf.distributions`
instance. See `default_mean_field_normal_fn` docstring for required
parameter signature.
Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
kernel_divergence_fn: Python `callable` which takes the surrogate posterior
distribution, prior distribution and random variate sample(s) from the
surrogate posterior and computes or approximates the KL divergence. The
distributions are `tf.distributions.Distribution`-like instances and the
sample is a `Tensor`.
bias_posterior_fn: Python `callable` which creates
`tf.distributions.Distribution` instance representing the surrogate
posterior of the `bias` parameter. Default value:
`default_mean_field_normal_fn(is_singular=True)` (which creates an
instance of `tf.distributions.Deterministic`).
bias_posterior_tensor_fn: Python `callable` which takes a
`tf.distributions.Distribution` instance and returns a representative
value. Default value: `lambda d: d.sample()`.
bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
See `default_mean_field_normal_fn` docstring for required parameter
signature. Default value: `None` (no prior, no variational inference).
bias_divergence_fn: Python `callable` which takes the surrogate posterior
distribution, prior distribution and random variate sample(s) from the
surrogate posterior and computes or approximates the KL divergence. The
distributions are `tf.distributions.Distribution`-like instances and the
sample is a `Tensor`.
name: Python `str`, the name of the layer. Layers with the same name will
share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
such cases.
reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
layer by the same name.
Returns:
output: `Tensor` representing the affine-transformed input under a random
draw from the surrogate posterior distribution.
"""
layer = DenseVariational(
units,
activation=activation,
activity_regularizer=activity_regularizer,
trainable=trainable,
kernel_use_local_reparameterization=(
kernel_use_local_reparameterization),
kernel_posterior_fn=kernel_posterior_fn,
kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
kernel_prior_fn=kernel_prior_fn,
kernel_divergence_fn=kernel_divergence_fn,
bias_posterior_fn=bias_posterior_fn,
bias_posterior_tensor_fn=bias_posterior_tensor_fn,
bias_prior_fn=bias_prior_fn,
bias_divergence_fn=bias_divergence_fn,
name=name,
dtype=inputs.dtype.base_dtype,
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
class NotSet(object):
"""Helper to track whether a `VariationalParameter` value has been set."""
pass
class VariationalParameter(object):
"""Struct-like container of variational parameter properties.
A `VariationalParameter` is initialized with Python `callable`s which set the
value of correspondingly named members. Corresponding values have "set once"
semantics, i.e., once set to any value they are immutable.
"""
def __init__(
self,
posterior_fn,
posterior_tensor_fn,
prior_fn,
divergence_fn):
"""Creates the `VariationalParameter` struct-like object.
Args:
posterior_fn: Python `callable` which creates a
`tf.distributions.Distribution`-like object representing the posterior
distribution. See `VariationalParameter.posterior_fn` for the `callable`'s
required parameters.
posterior_tensor_fn: Python `callable` which computes a `Tensor`
representing the `posterior`.
prior_fn: Python `callable` which creates a
`tf.distributions.Distribution`-like object representing the prior
distribution. See `VariationalParameter.prior_fn` for the `callable`'s
required parameters.
divergence_fn: Python `callable` which computes the KL divergence from
`posterior` to `prior`. See `VariationalParameter.divergence_fn` for the
`callable`'s required parameters.
"""
self._posterior_fn = posterior_fn
self._posterior = NotSet()
self._posterior_tensor_fn = posterior_tensor_fn
self._posterior_tensor = NotSet()
self._prior_fn = prior_fn
self._prior = NotSet()
self._divergence_fn = divergence_fn
self._divergence = NotSet()
self._init_helper()
@property
def posterior_fn(self):
"""`callable` which creates `tf.distributions.Distribution`-like posterior.
The `callable` must accept the following parameters:
name: Python `str` name prepended to any created (or existing)
`tf.Variable`s.
shape: Python `list`-like representing the parameter's event shape.
dtype: Type of parameter's event.
trainable: Python `bool` indicating all created `tf.Variable`s should be
added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
add_variable_fn: `tf.get_variable`-like `callable` used to create (or
access existing) `tf.Variable`s.
Returns:
posterior_fn: The Python `callable` specified in `__init__`.
"""
return self._posterior_fn
@property
def posterior(self):
"""`tf.distributions.Distribution`-like instance representing posterior."""
return self._posterior
@posterior.setter
def posterior(self, value):
"""One-time setter of the `posterior` distribution."""
if not isinstance(self._posterior, NotSet):
raise ValueError("Cannot override already set attribute.")
self._posterior = value
@property
def posterior_tensor_fn(self):
"""Creates `Tensor` representing the `posterior` distribution.
The `callable` must accept the following parameters:
posterior: `tf.distributions.Distribution`-like instance.
Returns:
posterior_tensor_fn: The Python `callable` specified in
`__init__`.
"""
return self._posterior_tensor_fn
@property
def posterior_tensor(self):
"""`Tensor` representing the `posterior` distribution."""
return self._posterior_tensor
@posterior_tensor.setter
def posterior_tensor(self, value):
"""One-time setter of the `posterior_tensor`."""
if not isinstance(self._posterior_tensor, NotSet):
raise ValueError("Cannot override already set attribute.")
self._posterior_tensor = value
@property
def prior_fn(self):
"""`callable` which creates `tf.distributions.Distribution`-like prior.
The `callable` must accept the following parameters:
name: Python `str` name prepended to any created (or existing)
`tf.Variable`s.
shape: Python `list`-like representing the parameter's event shape.
dtype: Type of parameter's event.
trainable: Python `bool` indicating all created `tf.Variable`s should be
added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
add_variable_fn: `tf.get_variable`-like `callable` used to create (or
access existing) `tf.Variable`s.
Returns:
prior_fn: The Python `callable` specified in `__init__`.
"""
return self._prior_fn
@property
def prior(self):
"""`tf.distributions.Distribution`-like instance representing posterior."""
return self._prior
@prior.setter
def prior(self, value):
"""One-time setter of the `prior` distribution."""
if not isinstance(self._prior, NotSet):
raise ValueError("Cannot override already set attribute.")
self._prior = value
@property
def divergence_fn(self):
"""`callable` which computes KL-divergence `Tensor` from posterior to prior.
The `callable` must accept the following parameters:
posterior: `tf.distributions.Distribution`-like instance.
prior: `tf.distributions.Distribution`-like instance.
posterior_tensor: `Tensor` representing value of posterior.
Returns:
divergence_fn: The Python `callable` specified in `__init__`.
"""
return self._divergence_fn
@property
def divergence(self):
"""`Tensor` representing KL-divergence from posterior to prior."""
return self._divergence
@divergence.setter
def divergence(self, value):
"""One-time setter of the `divergence`."""
if not isinstance(self._divergence, NotSet):
raise ValueError("Cannot override already set attribute.")
self._divergence = value
def _init_helper(self):
pass
class VariationalKernelParameter(VariationalParameter):
"""Struct-like container of variational kernel properties.
A `VariationalKernelParameter` is initialized with Python `callable`s which
set the value of correspondingly named members. Corresponding values have "set
once" semantics, i.e., once set to any value they are immutable.
"""
@property
def posterior_affine(self):
"""`tf.distributions.Distribution` affine transformed posterior."""
return self._posterior_affine
@posterior_affine.setter
def posterior_affine(self, value):
"""One-time setter of `posterior_affine`."""
if not isinstance(self._posterior_affine, NotSet):
raise ValueError("Cannot override already set attribute.")
self._posterior_affine = value
@property
def posterior_affine_tensor(self):
"""`Tensor` representing the `posterior_affine` distribution."""
return self._posterior_affine_tensor
@posterior_affine_tensor.setter
def posterior_affine_tensor(self, value):
"""One-time setter of the `posterior_affine_tensor`."""
if not isinstance(self._posterior_affine_tensor, NotSet):
raise ValueError("Cannot override already set attribute.")
self._posterior_affine_tensor = value
def _init_helper(self):
self._posterior_affine = NotSet()
self._posterior_affine_tensor = NotSet()

View File

@ -0,0 +1,34 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Probabilistic optimizer modules.
See ${python/contrib.bayesflow.optimizers}.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'SGLDOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,216 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An optimizer module for stochastic gradient Langevin dynamics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as varscope_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
class SGLDOptimizer(optimizer.Optimizer):
"""An optimizer module for stochastic gradient Langevin dynamics.
This implements the preconditioned Stochastic Gradient Langevin Dynamics
optimizer [1]. The optimization variable is regarded as a sample from the
posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in
each dimension according to RMSProp [2].
Note: If a prior is included in the loss, it should be scaled by
`1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches
in the data. I.e., it should be divided by the `num_pseudo_batches` term
described below.
[1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural
Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin.
ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666
[2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
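A usage sketch relying only on the standard `tf.train.Optimizer` interface
(`loss` and `num_batches_in_data` are assumed to exist):
```python
opt = SGLDOptimizer(learning_rate=1e-4,
                    preconditioner_decay_rate=0.95,
                    num_pseudo_batches=num_batches_in_data,
                    burnin=25)
train_op = opt.minimize(loss)
# During burnin the updates are plain preconditioned SGD; afterwards each
# step also injects scaled Gaussian noise, so the iterates are approximate
# posterior samples rather than a point estimate.
```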
Args:
learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the
optimizer. Must be tuned to the specific function being minimized.
preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential
decay rate of the rescaling of the preconditioner (RMSprop). (This is
"alpha" in [1]). Should be smaller than but nearly `1` to approximate
sampling from the posterior. (Default: `0.95`)
num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of
minibatches in the data set. Trades off noise and prior with the SGD
likelihood term. Note: Assumes the loss is taken as the mean over a
minibatch. Otherwise if the sum was taken, divide this number by the
batch size. (Default: `1`)
burnin: Scalar `int`-like `Tensor`. The number of iterations to collect
gradient statistics to update the preconditioner before starting to draw
noisy samples. (Default: `25`)
diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of
the preconditioner to prevent the preconditioner from degenerating.
(Default: `1e-8`)
name: Python `str` describing ops managed by this function.
(Default: `"SGLDOptimizer"`)
variable_scope: Variable scope used for calls to `tf.get_variable`.
If `None`, a new variable scope is created using name
`ops.get_default_graph().unique_name(name or default_name)`.
Raises:
InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in
`(0,1]`.
"""
def __init__(self,
learning_rate,
preconditioner_decay_rate=0.95,
num_pseudo_batches=1,
burnin=25,
diagonal_bias=1e-8,
name=None,
variable_scope=None):
default_name = 'SGLDOptimizer'
with ops.name_scope(name, default_name, [
learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin,
diagonal_bias
]):
if variable_scope is None:
var_scope_name = ops.get_default_graph().unique_name(
name or default_name)
with varscope_ops.variable_scope(var_scope_name) as scope:
self._variable_scope = scope
else:
self._variable_scope = variable_scope
self._preconditioner_decay_rate = ops.convert_to_tensor(
preconditioner_decay_rate, name='preconditioner_decay_rate')
self._num_pseudo_batches = ops.convert_to_tensor(
num_pseudo_batches, name='num_pseudo_batches')
self._burnin = ops.convert_to_tensor(burnin, name='burnin')
self._diagonal_bias = ops.convert_to_tensor(
diagonal_bias, name='diagonal_bias')
self._learning_rate = ops.convert_to_tensor(
learning_rate, name='learning_rate')
with varscope_ops.variable_scope(self._variable_scope):
self._counter = varscope_ops.get_variable(
'counter', initializer=0, trainable=False)
self._preconditioner_decay_rate = control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
self._preconditioner_decay_rate,
message='`preconditioner_decay_rate` must be non-negative'),
check_ops.assert_less_equal(
self._preconditioner_decay_rate,
1.,
message='`preconditioner_decay_rate` must be at most 1.'),
], self._preconditioner_decay_rate)
self._num_pseudo_batches = control_flow_ops.with_dependencies([
check_ops.assert_greater(
self._num_pseudo_batches,
0,
message='`num_pseudo_batches` must be greater than zero')
], self._num_pseudo_batches)
self._burnin = control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
self._burnin, message='`burnin` must be non-negative'),
check_ops.assert_integer(
self._burnin, message='`burnin` must be an integer')
], self._burnin)
self._diagonal_bias = control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
self._diagonal_bias,
message='`diagonal_bias` must be non-negative')
], self._diagonal_bias)
super(SGLDOptimizer, self).__init__(use_locking=False,
name=name or default_name)
def _create_slots(self, var_list):
for v in var_list:
init_rms = init_ops.ones_initializer(dtype=v.dtype)
self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
v.dtype, 'rms', self._name)
def _prepare(self):
# We need to put the conversion and check here because a user will likely
# want to decay the learning rate dynamically.
self._learning_rate_tensor = control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
self._learning_rate, message='`learning_rate` must be non-negative')
], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor'))
self._decay_tensor = ops.convert_to_tensor(
self._preconditioner_decay_rate, name='preconditioner_decay_rate')
super(SGLDOptimizer, self)._prepare()
def _apply_dense(self, grad, var):
rms = self.get_slot(var, 'rms')
with ops.control_dependencies([
self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor,
var.dtype.base_dtype))]):
new_grad = self._apply_noisy_update(rms, grad)
return training_ops.apply_gradient_descent(
var,
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
new_grad,
use_locking=self._use_locking).op
def _apply_sparse(self, grad, var):
rms = self.get_slot(var, 'rms')
with ops.control_dependencies([
self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor,
var.dtype.base_dtype))]):
new_grad = self._apply_noisy_update(rms, grad)
return training_ops.apply_gradient_descent(
var,
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
new_grad,
use_locking=self._use_locking).op
@property
def variable_scope(self):
"""Variable scope of all calls to `tf.get_variable`."""
return self._variable_scope
def _apply_noisy_update(self, mom, grad):
# Compute and apply the gradient update following
# preconditioned Langevin dynamics
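# The value returned below is a pseudo-gradient g' such that the
# gradient-descent step -learning_rate * g' realizes the preconditioned
# SGLD update:
#   drift:     0.5 * preconditioner * grad * num_pseudo_batches
#   diffusion: sqrt(preconditioner / learning_rate) * N(0, 1)  (0 in burnin)
# where preconditioner = 1 / sqrt(rms + diagonal_bias) and `rms` is the
# moving average of squared gradients kept by `_update_momentum`.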
stddev = array_ops.where(
array_ops.squeeze(self._counter > self._burnin),
math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype),
array_ops.zeros([], grad.dtype))
preconditioner = math_ops.rsqrt(
mom + math_ops.cast(self._diagonal_bias, grad.dtype))
return (
0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches,
grad.dtype) +
random_ops.random_normal(array_ops.shape(grad), dtype=grad.dtype) *
stddev * math_ops.sqrt(preconditioner))
def _update_momentum(self, mom, grad, decay):
# Keep an exponentially weighted moving average of squared gradients.
# Not thread-safe.
return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom))

View File

@ -63,19 +63,26 @@ const char* kPredictionsTensorName = "predictions";
void CalculateTreesToInclude(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const std::vector<int32>& trees_to_drop, const int32 num_trees,
const bool only_finalized, std::vector<int32>* trees_to_include) {
const bool only_finalized, const bool center_bias,
std::vector<int32>* trees_to_include) {
trees_to_include->reserve(num_trees - trees_to_drop.size());
int32 index = 0;
// This assumes that trees_to_drop is a sorted list of tree ids.
for (int32 tree = 0; tree < num_trees; ++tree) {
if ((!trees_to_drop.empty() && index < trees_to_drop.size() &&
trees_to_drop[index] == tree) ||
(only_finalized && config.tree_metadata_size() > 0 &&
!config.tree_metadata(tree).is_finalized())) {
// Skip the tree if it is in the list of trees_to_drop.
if (!trees_to_drop.empty() && index < trees_to_drop.size() &&
trees_to_drop[index] == tree) {
++index;
continue;
}
// Or skip if the tree is not finalized and only_finalized is set,
// with the exception of centering bias.
if (only_finalized && !(center_bias && tree == 0) &&
config.tree_metadata_size() > 0 &&
!config.tree_metadata(tree).is_finalized()) {
continue;
}
trees_to_include->push_back(tree);
}
}
@ -250,7 +257,7 @@ class GradientTreesPredictionOp : public OpKernel {
CalculateTreesToInclude(
ensemble_resource->decision_tree_ensemble(), dropped_trees,
ensemble_resource->decision_tree_ensemble().trees_size(),
only_finalized_trees_, &trees_to_include);
only_finalized_trees_, center_bias_, &trees_to_include);
// Allocate output predictions matrix.
Tensor* output_predictions_t = nullptr;

View File

@ -237,6 +237,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
VLOG(1) << "Continuing to center bias, delta=" << total_delta;
} else {
VLOG(1) << "Done centering bias, delta=" << total_delta;
ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
}
Tensor* continue_centering_t = nullptr;
OP_REQUIRES_OK(
@ -260,7 +261,6 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
for (size_t idx = 0; idx < logits_dimension; ++idx) {
leaf->mutable_vector()->add_value(0.0);
}
ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
return leaf;
} else if (num_trees == 1) {
// Confirms that the only tree is a bias and returns its leaf.

View File

@ -181,7 +181,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
tree_weights: 1.0
tree_metadata {
num_layers_grown: 1
is_finalized: true
}
growing_metadata {
num_trees_attempted: 1
@ -189,7 +188,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
}
"""
self.assertEqual(new_stamp, 1)
self.assertEqual(stats.num_trees, 1)
self.assertEqual(stats.num_trees, 0)
self.assertEqual(stats.num_layers, 1)
self.assertEqual(stats.active_tree, 1)
self.assertEqual(stats.active_layer, 1)
@ -231,7 +230,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
tree_weights: 1.0
tree_metadata {
num_layers_grown: 1
is_finalized: true
}
growing_metadata {
num_trees_attempted: 1
@ -239,7 +237,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
}
"""
self.assertEqual(new_stamp, 2)
self.assertEqual(stats.num_trees, 1)
self.assertEqual(stats.num_trees, 0)
self.assertEqual(stats.num_layers, 1)
self.assertEqual(stats.active_tree, 1)
self.assertEqual(stats.active_layer, 1)

View File

@ -238,6 +238,7 @@ add_python_module("tensorflow/python/keras/datasets")
add_python_module("tensorflow/python/keras/datasets/boston_housing")
add_python_module("tensorflow/python/keras/datasets/cifar10")
add_python_module("tensorflow/python/keras/datasets/cifar100")
add_python_module("tensorflow/python/keras/datasets/fashion_mnist")
add_python_module("tensorflow/python/keras/datasets/imdb")
add_python_module("tensorflow/python/keras/datasets/mnist")
add_python_module("tensorflow/python/keras/datasets/reuters")

View File

@ -23,7 +23,6 @@ import itertools
import numpy as np
from tensorflow.contrib.crf.python.ops import crf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -58,18 +57,19 @@ class CrfTest(test.TestCase):
def testCrfUnaryScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
unary_score = crf.crf_unary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
inputs=array_ops.expand_dims(inputs, 0))
unary_score = array_ops.squeeze(unary_score, [0])
tf_unary_score = sess.run(unary_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
self.assertAllClose(tf_unary_score, expected_unary_score)
for dtype in (np.int32, np.int64):
tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
unary_score = crf.crf_unary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
inputs=array_ops.expand_dims(inputs, 0))
unary_score = array_ops.squeeze(unary_score, [0])
tf_unary_score = sess.run(unary_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
self.assertAllClose(tf_unary_score, expected_unary_score)
def testCrfBinaryScore(self):
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)

View File

@ -193,6 +193,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
offsets = array_ops.expand_dims(
math_ops.range(batch_size) * max_seq_len * num_tags, 1)
offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
# Use int32 or int64 based on tag_indices' dtype.
if tag_indices.dtype == dtypes.int64:
offsets = math_ops.to_int64(offsets)
flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])
unary_scores = array_ops.reshape(
@ -305,7 +308,7 @@ def viterbi_decode(score, transition_params):
Returns:
viterbi: A [seq_len] list of integers containing the highest scoring tag
indicies.
indices.
viterbi_score: A float containing the score for the Viterbi sequence.
"""
trellis = np.zeros_like(score)
@ -385,7 +388,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Initialize the CrfDecodeBackwardRnnCell.
Args:
num_tags: An integer.
num_tags: An integer. The number of tags.
"""
self._num_tags = num_tags
@ -434,9 +437,9 @@ def crf_decode(potentials, transition_params, sequence_length):
sequence_length: A [batch_size] vector of true sequence lengths.
Returns:
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
Contains the highest scoring tag indicies.
best_score: A [batch_size] vector, containing the score of `decode_tags`.
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
Contains the highest scoring tag indices.
best_score: A [batch_size] tensor, containing the score of decode_tags.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).

View File

@ -270,6 +270,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -277,15 +278,20 @@ py_test(
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],
)


@ -25,9 +25,13 @@ import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
@ -40,7 +44,10 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import compat
@ -668,6 +675,515 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testCaptureResourceInMapFn(self):
def _build_ds(iterator):
def _map_fn(x):
get_next = iterator.get_next()
return x * get_next
return dataset_ops.Dataset.range(10).map(_map_fn)
def _build_graph():
captured_iterator = dataset_ops.Dataset.range(
10).make_initializable_iterator()
ds = _build_ds(captured_iterator)
iterator = ds.make_initializable_iterator()
init_op = iterator.initializer
return captured_iterator.initializer, init_op
with ops.Graph().as_default() as g:
captured_init_op, init_op = _build_graph()
with self.test_session(graph=g) as sess:
sess.run(captured_init_op)
with self.assertRaises(errors.UnimplementedError):
# CapturedFunction does not support capturing IteratorResource.
sess.run(init_op)
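The serialization tests that follow all hinge on one pattern: register the iterator's SaveableObject so a plain Saver checkpoints iterator state. A condensed sketch of that pattern, using the same modules this test file imports:

ds = dataset_ops.Dataset.range(100)
iterator = ds.make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = saver_lib.Saver()  # picks the saveable up from the collection
# After running some get_next() steps, saver.save() checkpoints the iterator
# position and saver.restore() resumes from it.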
class MapDatasetSerializationTest(test.TestCase):
def setUp(self):
self._tensor_slice_len = 7
self._num_epochs = 14
self._num_outputs = self._tensor_slice_len * self._num_epochs
def tearDown(self):
# Remove all checkpoint files.
prefix = self._ckpt_path()
pattern = prefix + "*"
files = gfile.Glob(pattern)
for f in files:  # explicit loop; map() would be lazy under Python 3
gfile.Remove(f)
def _build_ds(self, multiplier=37.0):
components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
np.arange(self._tensor_slice_len)[:, np.newaxis],
np.array(multiplier) * np.arange(self._tensor_slice_len))
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
.repeat(self._num_epochs))
def _build_graph(self, multiplier=37.0, build_saveable=True):
ds = self._build_ds(multiplier)
iterator = ds.make_initializable_iterator()
if build_saveable:
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
get_next = iterator.get_next()
self._add_iterator_ops_to_collection(init_op, get_next)
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
def _build_empty_graph(self, output_types, output_shapes):
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = saver_lib.Saver()
get_next = iterator.get_next()
return get_next, saver
def _add_iterator_ops_to_collection(self, init_op, get_next):
ops.add_to_collection("iterator_ops", init_op)
ops.add_to_collection("iterator_ops", get_next[0])
ops.add_to_collection("iterator_ops", get_next[1])
ops.add_to_collection("iterator_ops", get_next[2])
def _get_iterator_ops_from_collection(self):
init_op, get_next_1, get_next_2, get_next_3 = ops.get_collection(
"iterator_ops")
return init_op, (get_next_1, get_next_2, get_next_3)
def _ckpt_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
return saver_lib.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())
def _restore(self, saver, sess):
saver.restore(sess, self._latest_ckpt())
def _import_meta_graph(self):
meta_file_path = self._ckpt_path() + ".meta"
return saver_lib.import_meta_graph(meta_file_path)
def _testReadWithBreaks(self, break_points, init_before_restore=False):
expected = []
actual = []
# Generate the ground truth.
with ops.Graph().as_default() as g:
init_op, get_next_op, _ = self._build_graph()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(self._num_outputs):
expected.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Run and checkpoint after first break_point.
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_points[0]):
actual.append(sess.run(get_next_op))
self._save(sess, saver)
# Load from checkpoint and continue running while stopping at each
# subsequent checkpoint.
for i in range(len(break_points)):
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection()
with self.test_session(graph=g) as sess:
if init_before_restore:
sess.run(init_op)
self._restore(saver, sess)
start = break_points[i]
end = break_points[
i + 1] if i < len(break_points) - 1 else self._num_outputs
for _ in range(end - start):
actual.append(sess.run(get_next_op))
self._save(sess, saver)
if end == self._num_outputs:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self._match(expected, actual)
def _match(self, expected, actual):
self.assertEqual(len(expected), len(actual))
for expected_tuple, actual_tuple in zip(expected, actual):
self.assertEqual(expected_tuple[0], actual_tuple[0])
self.assertSequenceEqual(expected_tuple[1].tolist(),
actual_tuple[1].tolist())
self.assertEqual(expected_tuple[2], actual_tuple[2])
def _does_not_match(self, expected, actual):
with self.assertRaises(AssertionError):
self._match(expected, actual)
def testSaveRestore(self):
self._testReadWithBreaks([4])
self._testReadWithBreaks([13])
self._testReadWithBreaks([18])
self._testReadWithBreaks([23])
def testSaveUnusedIterator(self):
self._testReadWithBreaks([0])
def testSaveFullyUsedIterator(self):
self._testReadWithBreaks([self._num_outputs])
def testMultipleBreaks(self):
self._testReadWithBreaks([0, 5, 9, 15, 25, 32])
def testIdempotence(self):
# Attempt to save iterator immediately after restoring.
self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32])
def testInitThenRestore(self):
self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True)
def testRestoreExhaustedIterator(self):
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(self._num_outputs):
sess.run(get_next_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self._save(sess, saver)
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection()
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
def testResetRestoredIterator(self):
expected = []
# Collect ground truth containing all outputs.
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph()
break_point = self._num_outputs // 2
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
expected.append(sess.run(get_next_op))
self._save(sess, saver)
for _ in range(self._num_outputs - break_point):
expected.append(sess.run(get_next_op))
actual = []
# Restore from checkpoint and then run init_op.
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection()
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
sess.run(init_op)
for _ in range(self._num_outputs):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self._match(expected, actual)
def testRestoreInModifiedGraph(self):
expected = []
actual_without_restore = []
actual = []
break_point = 10
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(multiplier=15.0)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
expected.append(sess.run(get_next_op))
actual.extend(expected)
self._save(sess, saver)
for _ in range(self._num_outputs - break_point):
expected.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Collect outputs by running modified graph.
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(multiplier=30.0)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(self._num_outputs):
actual_without_restore.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Restore the checkpoint in the modified graph.
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(multiplier=30.0)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(self._num_outputs - break_point):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Ensure the modified graph gets overridden when restoring checkpoint.
self._does_not_match(expected, actual_without_restore)
# Expect that the outputs are what we would expect if we ran the old
# graph.
self._match(expected, actual)
# TODO(srbs): Add this test to dataset_serialization_test_base.py.
def testRestoreInEmptyGraph(self):
expected = []
actual = []
break_point = 10
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(multiplier=15.0)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
self._save(sess, saver)
for _ in range(self._num_outputs - break_point):
expected.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
with ops.Graph().as_default() as g:
ds = self._build_ds()
output_types = ds.output_types
output_shapes = ds.output_shapes
with ops.Graph().as_default() as g:
get_next_op, saver = self._build_empty_graph(output_types, output_shapes)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(self._num_outputs - break_point):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Expect that the outputs are what we would expect if we ran the old
# graph.
self._match(expected, actual)
def testDoNotBuildSaveable(self):
break_point = 10
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(multiplier=15.0)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
self._save(sess, saver)
expected = []
# Collect ground truth by running modified graph.
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(multiplier=30.0)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(self._num_outputs):
expected.append(sess.run(get_next_op))
actual = []
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(
multiplier=30.0, build_saveable=False)
with self.test_session(graph=g) as sess:
# Since the SaveableObject was not added to Saver's list
# of saveables, iterator state is not restored by saver.restore().
self._restore(saver, sess)
with self.assertRaises(errors.FailedPreconditionError):
sess.run(get_next_op)
sess.run(init_op)
for _ in range(self._num_outputs):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self._match(expected, actual)
def testSaveStatefulFunction(self):
def _build_ds():
def _map_fn(x):
return random_ops.random_uniform(
(), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
return dataset_ops.Dataset.range(100).map(_map_fn)
def _build_graph():
ds = _build_ds()
iterator = ds.make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
get_next = iterator.get_next()
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
break_point = 10
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = _build_graph()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
with self.assertRaises(errors.InvalidArgumentError):
self._save(sess, saver)
def testCaptureVariableInMapFn(self):
def _build_ds():
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
lambda _: counter_var.assign_add(1)))
def _build_graph():
ds = _build_ds()
iterator = ds.make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
get_next = iterator.get_next()
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
break_point = 10
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = _build_graph()
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
with self.assertRaises(errors.InvalidArgumentError):
self._save(sess, saver)
def testCaptureDefunInMapFn(self):
num_outputs = 100
def _build_ds():
@function.Defun(dtypes.int64)
def defun_fn(x):
return constant_op.constant(1000) + math_ops.to_int32(x)
return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
def _build_graph():
ds = _build_ds()
iterator = ds.make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
get_next = iterator.get_next()
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
break_point = 10
expected = []
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = _build_graph()
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
self._save(sess, saver)
for _ in range(num_outputs - break_point):
expected.append(sess.run(get_next_op))
with ops.Graph().as_default() as g:
ds = _build_ds()
output_types = ds.output_types
output_shapes = ds.output_shapes
actual = []
with ops.Graph().as_default() as g:
get_next_op, saver = self._build_empty_graph(output_types, output_shapes)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.assertSequenceEqual(expected, actual)
def testBuildDefunInMapFn(self):
num_outputs = 100
def _build_ds():
@function.Defun(dtypes.int64)
def defun_fn(x):
@function.Defun(dtypes.int32)
def defun_fn_deep(x):
return constant_op.constant(1000) + math_ops.to_int32(x)
return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
def _build_graph():
ds = _build_ds()
iterator = ds.make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
get_next = iterator.get_next()
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
break_point = 10
expected = []
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = _build_graph()
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
self._save(sess, saver)
for _ in range(num_outputs - break_point):
expected.append(sess.run(get_next_op))
with ops.Graph().as_default() as g:
ds = _build_ds()
output_types = ds.output_types
output_shapes = ds.output_shapes
actual = []
with ops.Graph().as_default() as g:
get_next_op, saver = self._build_empty_graph(output_types, output_shapes)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.assertSequenceEqual(expected, actual)
if __name__ == "__main__":
test.main()


@ -50,21 +50,22 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/contrib/data/python/ops:prefetching_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/eager:context",
],
)
py_test(
cuda_py_test(
name = "datasets_test",
srcs = ["datasets_test.py"],
srcs_version = "PY2AND3",
deps = [
additional_deps = [
":datasets",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
@ -240,6 +241,7 @@ py_test(
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:test",
],
)


@ -20,11 +20,15 @@ from __future__ import print_function
import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
@ -32,12 +36,12 @@ _uid_counter = 0
_uid_lock = threading.Lock()
def _iterator_shared_name():
def _generate_shared_name(prefix):
with _uid_lock:
global _uid_counter
uid = _uid_counter
_uid_counter += 1
return "eager_iterator_{}".format(uid)
return "{}_{}".format(prefix, uid)
class Iterator(object):
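A hypothetical illustration of the renamed helper: the counter is global across prefixes, so interleaved calls might produce:

print(_generate_shared_name("eager_iterator"))             # "eager_iterator_0"
print(_generate_shared_name("function_buffer_resource"))   # "function_buffer_resource_1"
print(_generate_shared_name("eager_iterator"))             # "eager_iterator_2"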
@ -72,11 +76,12 @@ class Iterator(object):
with ops.device("/device:CPU:0"):
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
self._output_types = dataset.output_types
self._output_shapes = dataset.output_shapes
self._flat_output_types = nest.flatten(dataset.output_types)
self._flat_output_shapes = nest.flatten(dataset.output_shapes)
self._resource = gen_dataset_ops.iterator(
container="",
shared_name=_iterator_shared_name(),
shared_name=_generate_shared_name("eager_iterator"),
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)
@ -84,6 +89,35 @@ class Iterator(object):
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device="/device:CPU:0")
self._device = context.context().device_name
self._buffer_resource_handle = None
if not context.context().device_spec.device_type:
is_remote_device = False
else:
is_remote_device = context.context().device_spec.device_type != "CPU"
if is_remote_device:
with ops.device("/device:CPU:0"):
iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
self._resource)
@function.Defun(dtypes.string)
def remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, self._output_types, self._output_shapes)
return remote_iterator.get_next()
remote_fn.add_to_graph(None)
target = constant_op.constant("/device:CPU:0")
with ops.device(self._device):
self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
string_arg=iter_string_handle,
f=remote_fn,
target_device=target,
buffer_size=10,
thread_pool_size=1,
container="",
shared_name=_generate_shared_name("function_buffer_resource"))
self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._buffer_resource_handle, handle_device=self._device)
def __iter__(self):
return self
@ -93,20 +127,20 @@ class Iterator(object):
def next(self):
"""Return the next tf.Tensor from the dataset."""
try:
# TODO(ashankar): Consider removing this ops.device() contextmanager
# and instead mimic ops placement in graphs: Operations on resource
# handles execute on the same device as where the resource is placed.
with ops.device("/device:CPU:0"):
ret = gen_dataset_ops.iterator_get_next(
self._resource,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
except errors.OutOfRangeError:
raise StopIteration
# Copies tensors from CPU to the current device if necessary.
# TODO(rohanj): This should be replaced by the mechanism to have the
# runtime's threads copy tensors to the destination device.
with ops.device(self._device):
ret = [array_ops.identity(x) for x in ret]
try:
if self._buffer_resource_handle is not None:
ret = prefetching_ops.function_buffering_resource_get_next(
function_buffer_resource=self._buffer_resource_handle,
output_types=self._flat_output_types)
else:
# TODO(ashankar): Consider removing this ops.device() contextmanager
# and instead mimic ops placement in graphs: Operations on resource
# handles execute on the same device as where the resource is placed.
ret = gen_dataset_ops.iterator_get_next(
self._resource,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
except errors.OutOfRangeError:
raise StopIteration
return nest.pack_sequence_as(self._output_types, ret)
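Putting the pieces together, a hedged usage sketch (device placement hypothetical): on a CPU-only setup next() calls iterator_get_next directly, while under a non-CPU device scope construction installs the function buffering resource and next() reads from it instead:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()
ds = tf.data.Dataset.range(4)
with tf.device("/device:GPU:0"):  # hypothetical; omit on CPU-only machines
    for x in tfe.Iterator(ds):
        print(x)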


@ -87,7 +87,7 @@ class EvaluatorTest(test.TestCase):
e.all_metric_results(logdir)
events = summary_test_util.events_from_file(logdir)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
@ -136,7 +136,7 @@ class EvaluatorTest(test.TestCase):
variables.global_variables_initializer().run()
e.run_evaluation(init_op, call_op, results_op)
events = summary_test_util.events_from_file(logdir)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 6.0)


@ -95,7 +95,7 @@ class ResNet50GraphTest(tf.test.TestCase):
sess.run([train_op, tf.contrib.summary.all_summary_ops()],
feed_dict={images: np_images, labels: np_labels})
events = summary_test_util.events_from_file(logdir)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss')


@ -103,7 +103,7 @@ class ResNet50Test(tf.test.TestCase):
images, labels = random_batch(2)
train_one_step(model, images, labels, optimizer)
self.assertEqual(320, len(model.variables))
events = summary_test_util.events_from_file(logdir)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss')


@ -72,7 +72,7 @@ class MetricsTest(test.TestCase):
name="t0").as_default(), summary_ops.always_record_summaries():
m.result() # As a side-effect will write summaries.
events = summary_test_util.events_from_file(logdir)
events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 37.0)

File diff suppressed because it is too large.


@ -19,9 +19,12 @@ from __future__ import print_function
import gc
from tensorflow.contrib.eager.python import network
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
@ -46,8 +49,8 @@ class NetworkTest(test.TestCase):
def _save_modify_load_network_built(self, net, global_step=None):
checkpoint_directory = self.get_temp_dir()
checkpoint_path = net.save(
save_path=checkpoint_directory, global_step=global_step)
checkpoint_path = network.save_network_checkpoint(
network=net, save_path=checkpoint_directory, global_step=global_step)
input_value = constant_op.constant([[42.0]])
original_output = self.evaluate(net(input_value))
for var in net.variables:
@ -56,13 +59,13 @@ class NetworkTest(test.TestCase):
self.evaluate(net(input_value)),
original_output)
# Either the returned explicit checkpoint path or the directory should work.
net.restore(save_path=checkpoint_directory)
network.restore_network_checkpoint(net, save_path=checkpoint_directory)
self.assertAllEqual(
original_output,
self.evaluate(net(input_value)))
for var in net.variables:
self.evaluate(var.assign(var + 2.))
net.restore(save_path=checkpoint_path)
network.restore_network_checkpoint(net, save_path=checkpoint_path)
self.assertAllEqual(
original_output,
self.evaluate(net(input_value)))
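The migration pattern, condensed (MyNetwork is this test file's fixture; checkpoint_directory is a hypothetical path):

net = MyNetwork()
net(constant_op.constant([[2.0]]))     # variables must exist before saving
save_path = network.save_network_checkpoint(net, checkpoint_directory)
network.restore_network_checkpoint(net, save_path)  # the directory also works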
@ -85,13 +88,30 @@ class NetworkTest(test.TestCase):
result = net(constant_op.constant([[2.0]]))
self.assertEqual(34.0, self.evaluate(result))
# TODO(akshayka): This test should be changed once an API for compiling
# `call` into a defun is implemented.
def testReplacingNetworkCallWithDefun(self):
net = MyNetwork(name="abcd")
x = constant_op.constant([[2.0]])
net(x) # Force variables to be created.
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
net.call = function.defun(net.call)
result = net(x) # Build and execute the TensorFlow function
self.assertEqual(34.0, self.evaluate(result))
# Force the creation of another TensorFlow function by changing input shape
y = constant_op.constant([[1.0], [2.0]])
result = net(y)
self.assertAllEqual([[17.0], [34.0]], self.evaluate(result))
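function.defun, used above, traces a Python function into a graph function per input signature; a minimal standalone sketch, assuming eager execution is enabled:

from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op

square = function.defun(lambda x: x * x)
print(square(constant_op.constant(3.0)))        # first call traces a graph function
print(square(constant_op.constant([1., 2.])))   # new input shape triggers a retrace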
# TODO(allenl): This test creates garbage in some Python versions
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveRestoreAlreadyBuilt(self):
net = MyNetwork(name="abcd")
with self.assertRaisesRegexp(
ValueError, "Attempt to save the Network before it was first called"):
net.save(self.get_temp_dir())
network.save_network_checkpoint(net, self.get_temp_dir())
net(constant_op.constant([[2.0]]))
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
self._save_modify_load_network_built(net, global_step=None)
@ -105,7 +125,7 @@ class NetworkTest(test.TestCase):
self.evaluate(net.variables[0].assign([[3.]]))
default_global_step = training_util.get_or_create_global_step()
self.evaluate(default_global_step.assign(4242))
save_path = net.save(self.get_temp_dir())
save_path = network.save_network_checkpoint(net, self.get_temp_dir())
self.assertIn("abcd-4242", save_path)
# TODO(allenl): This test creates garbage in some Python versions
@ -116,16 +136,43 @@ class NetworkTest(test.TestCase):
test_input = constant_op.constant([[2.0]])
net1(test_input)
self.evaluate(net1.trainable_variables[0].assign([[17.0]]))
save_path = net1.save(save_dir)
save_path = network.save_network_checkpoint(net1, save_dir)
# With a pre-build restore we should have the same value.
net2 = MyNetwork()
net2.restore(save_path)
network.restore_network_checkpoint(net2, save_path)
self.assertAllEqual(self.evaluate(net1(test_input)),
self.evaluate(net2(test_input)))
self.assertIsNot(net1.variables[0], net2.variables[0])
self.assertAllEqual(self.evaluate(net1.variables[0]),
self.evaluate(net2.variables[0]))
@test_util.run_in_graph_and_eager_modes()
def testNetworkMatchesLayerVariableNames(self):
zero = constant_op.constant([[0.]])
layer_one = core.Dense(1, use_bias=False)
layer_one(zero)
layer_two = core.Dense(1, use_bias=False)
layer_two(zero)
class TwoLayerNet(network.Network):
def __init__(self, name=None):
super(TwoLayerNet, self).__init__(name=name)
self.first = self.track_layer(core.Dense(
1, use_bias=False))
self.second = self.track_layer(core.Dense(
1, use_bias=False))
def call(self, x):
return self.second(self.first(x))
net = TwoLayerNet()
net(zero)
self.assertEqual("two_layer_net/" + layer_one.variables[0].name,
net.first.variables[0].name)
self.assertEqual("two_layer_net/" + layer_two.variables[0].name,
net.second.variables[0].name)
@test_util.run_in_graph_and_eager_modes()
def testLoadIntoUnbuiltSharedLayer(self):
@ -173,14 +220,15 @@ class NetworkTest(test.TestCase):
# Re-map the variable names so that with default restore mapping we'll
# attempt to restore into the unbuilt Layer.
name_mapping = {
"checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel",
"checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel",
"checkpoint_creator/second_layer/kernel": "second_layer/kernel",
}
save_path = checkpoint_creator.save(
save_path = network.save_network_checkpoint(
checkpoint_creator,
self.get_temp_dir(),
map_func=lambda full_name: name_mapping[full_name])
load_into = User(use_layer=first_owner.first)
load_into.restore(save_path)
network.restore_network_checkpoint(load_into, save_path)
self.assertEqual(0, len(first_owner.variables))
self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
self.evaluate(load_into(one)))
@ -196,12 +244,13 @@ class NetworkTest(test.TestCase):
del first_owner
gc.collect()
def _restore_map_func(original_name):
if original_name.startswith("owner_1"):
return original_name.replace("owner_1", "owner_2")
if original_name.startswith("owner/"):
return original_name.replace("owner/", "owner_1/")
else:
return "user_2/" + original_name
return "user_1/" + original_name
with self.assertRaisesRegexp(ValueError, "garbage collected"):
load_into.restore(save_path, map_func=_restore_map_func)
network.restore_network_checkpoint(
load_into, save_path, map_func=_restore_map_func)
@test_util.run_in_graph_and_eager_modes()
def testRestoreIntoSubNetwork(self):
@ -221,17 +270,18 @@ class NetworkTest(test.TestCase):
whole_model_saver(one)
self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir())
whole_model_checkpoint = network.save_network_checkpoint(
whole_model_saver, self.get_temp_dir())
save_from = MyNetwork()
save_from(one)
self.evaluate(save_from.variables[0].assign([[5.]]))
checkpoint = save_from.save(self.get_temp_dir())
checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir())
save_into_parent = Parent()
save_into_parent.restore(whole_model_checkpoint)
save_into_parent.first.restore(checkpoint)
save_into_parent.first.restore(checkpoint) # deferred loading multiple
# times is fine
network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
network.restore_network_checkpoint(save_into_parent.first, checkpoint)
# deferred loading multiple times is fine
network.restore_network_checkpoint(save_into_parent.first, checkpoint)
save_into_parent(one) # deferred loading
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
@ -240,9 +290,9 @@ class NetworkTest(test.TestCase):
# (deferred restoration should happen the same way non-deferred happens,
# with later restorations overwriting older ones).
save_into_parent = Parent()
save_into_parent.first.restore(checkpoint) # deferred loading multiple
# times is fine
save_into_parent.restore(whole_model_checkpoint)
# deferred loading multiple times is fine
network.restore_network_checkpoint(save_into_parent.first, checkpoint)
network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
save_into_parent(one) # deferred loading
# We've overwritten the sub-Network restore.
self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
@ -250,12 +300,12 @@ class NetworkTest(test.TestCase):
self.evaluate(save_into_parent.variables[0].assign([[3.]]))
self.evaluate(save_into_parent.variables[1].assign([[4.]]))
save_into_parent.second.restore(checkpoint)
network.restore_network_checkpoint(save_into_parent.second, checkpoint)
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
with self.assertRaisesRegexp(errors_impl.NotFoundError,
"not found in checkpoint"):
# The checkpoint is incompatible.
save_into_parent.restore(checkpoint)
network.restore_network_checkpoint(save_into_parent, checkpoint)
@test_util.run_in_graph_and_eager_modes()
def testCustomMapCollisionErrors(self):
@ -277,31 +327,36 @@ class NetworkTest(test.TestCase):
self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
with self.assertRaisesRegexp(
ValueError,
"The map_func passed to Network.save for the Network 'parent_1' "
"resulted in two variables named 'foo'"):
make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo")
checkpoint = make_checkpoint.first.save(
self.get_temp_dir(), map_func=lambda n: "foo")
"The map_func passed to save_network_checkpoint for the Network "
"'parent' resulted in two variables named 'foo'"):
network.save_network_checkpoint(
make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo")
checkpoint = network.save_network_checkpoint(
network=make_checkpoint.first,
save_path=self.get_temp_dir(),
map_func=lambda n: "foo")
loader = Parent()
loader.restore(checkpoint, map_func=lambda n: "foo")
network.restore_network_checkpoint(
loader, checkpoint, map_func=lambda n: "foo")
with self.assertRaisesRegexp(
ValueError,
("The map_func passed to Network.restore for the Network"
" 'parent_2' resulted in two variables named 'foo'")):
("The map_func passed to restore_network_checkpoint for the Network"
" 'parent_1' resulted in two variables named 'foo'")):
loader(one)
loader = Parent()
loader(one)
with self.assertRaisesRegexp(
ValueError,
("The map_func passed to Network.restore for the Network"
" 'parent_3' resulted in two variables named 'foo'")):
loader.restore(checkpoint, map_func=lambda n: "foo")
("The map_func passed to restore_network_checkpoint for the Network"
" 'parent_2' resulted in two variables named 'foo'")):
network.restore_network_checkpoint(
loader, checkpoint, map_func=lambda n: "foo")
@test_util.run_in_graph_and_eager_modes()
def testDefaultMapCollisionErrors(self):
one = constant_op.constant([[1.]])
first = core.Dense(1, name="dense_1", use_bias=False)
first = core.Dense(1, name="dense", use_bias=False)
first(one)
class Parent(network.Network):
@ -322,8 +377,8 @@ class NetworkTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
("The default checkpoint variable name mapping strategy for Network "
"'parent_1' resulted in a naming conflict.")):
make_checkpoint.save(self.get_temp_dir())
"'parent' resulted in a naming conflict.")):
network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())
class Compatible(network.Network):
@ -337,14 +392,15 @@ class NetworkTest(test.TestCase):
successful_checkpoint = Compatible()
successful_checkpoint(one)
self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
checkpoint_path = successful_checkpoint.save(self.get_temp_dir())
checkpoint_path = network.save_network_checkpoint(
successful_checkpoint, self.get_temp_dir())
load_checkpoint = Parent()
load_checkpoint(one)
with self.assertRaisesRegexp(
ValueError,
("The default checkpoint variable name mapping strategy for Network "
"'parent_2' resulted in a naming conflict.")):
load_checkpoint.restore(checkpoint_path)
"'parent_1' resulted in a naming conflict.")):
network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
def testNoReferenceCyclesAfterCall(self):
@ -398,6 +454,36 @@ class NetworkTest(test.TestCase):
self.assertIsInstance(net.trainable_weights[0],
resource_variable_ops.ResourceVariable)
def testGraphOpNames(self):
"""Network operation names should match variable naming."""
def _check_op_prefixes(expected_prefix, checked_ops):
for operation in ops.get_default_graph().get_operations():
if operation.name == "ignore":
continue
if operation.name in checked_ops:
continue
checked_ops.add(operation.name)
self.assertStartsWith(expected_start=expected_prefix,
actual=operation.name)
self.assertNotIn("my_network", operation.name[len(expected_prefix):])
self.assertNotIn("dense", operation.name[len(expected_prefix):])
with context.graph_mode():
net = MyNetwork()
zero = constant_op.constant([[0.]], name="ignore")
net(zero)
checked_ops = set()
_check_op_prefixes(expected_prefix="my_network/dense/",
checked_ops=checked_ops)
net.net2 = net.track_layer(MyNetwork())
net.net2(zero)
_check_op_prefixes(expected_prefix="my_network/my_network/dense/",
checked_ops=checked_ops)
MyNetwork()(zero)
_check_op_prefixes(expected_prefix="my_network_1/dense/",
checked_ops=checked_ops)
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testDuplicateNameError(self):
one = constant_op.constant([[1.]])
@ -414,25 +500,25 @@ class NetworkTest(test.TestCase):
# Naming happens in the order of first build rather than the order of
# construction, but for clarity they're the same here and construction is
# annotated.
outside_net_before = MyNetwork() # name=my_network_1
outside_net_before = MyNetwork() # name=my_network
outside_net_before(one)
captured_scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope("outside_scope"):
net1 = MyNetwork() # name=outside_scope/my_network_1
net1 = MyNetwork() # name=outside_scope/my_network
net1(one)
name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far
name_conflict2 = MyNetwork(name="name_conflict") # error on build
with variable_scope.variable_scope("inside_scope"):
# No issue here since the name is unique within its scope.
name_conflict3 = MyNetwork(name="name_conflict")
net2 = MyNetwork() # name=outside_scope/my_network_3 to avoid the
# variable_scope my_network_2 below.
net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the
# variable_scope my_network_1 below.
vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below
with variable_scope.variable_scope("intervening_scope"):
with variable_scope.variable_scope(captured_scope):
with variable_scope.variable_scope("outside_scope"):
name_conflict4 = MyNetwork(name="name_conflict") # error on build
with variable_scope.variable_scope("my_network_2"):
with variable_scope.variable_scope("my_network_1"):
pass
with variable_scope.variable_scope("vs_name_conflict"):
pass
@ -452,35 +538,35 @@ class NetworkTest(test.TestCase):
self.assertEqual("outside_scope/name_conflict",
name_conflict1.name)
self.assertStartsWith(
expected_start="outside_scope/name_conflict/dense_1/",
expected_start="outside_scope/name_conflict/dense/",
actual=name_conflict1.variables[0].name)
self.assertEqual("outside_scope/inside_scope/name_conflict",
name_conflict3.name)
self.assertStartsWith(
expected_start="outside_scope/inside_scope/name_conflict/dense_1/",
expected_start="outside_scope/inside_scope/name_conflict/dense/",
actual=name_conflict3.variables[0].name)
self.assertEqual("outside_scope/my_network_1", net1.name)
self.assertEqual("outside_scope/my_network", net1.name)
self.assertStartsWith(
expected_start="outside_scope/my_network_1/dense_1/",
expected_start="outside_scope/my_network/dense/",
actual=net1.trainable_weights[0].name)
self.assertEqual("outside_scope/my_network_3", net2.name)
self.assertEqual("outside_scope/my_network_2", net2.name)
self.assertStartsWith(
expected_start="outside_scope/my_network_3/dense_1/",
expected_start="outside_scope/my_network_2/dense/",
actual=net2.trainable_weights[0].name)
net3(one)
self.assertEqual("outside_scope/my_network_4", net3.name)
self.assertEqual("outside_scope/my_network_3", net3.name)
self.assertStartsWith(
expected_start="outside_scope/my_network_4/dense_1/",
expected_start="outside_scope/my_network_3/dense/",
actual=net3.trainable_weights[0].name)
outside_net_after = MyNetwork()
outside_net_after(one)
self.assertEqual("my_network_1", outside_net_before.name)
self.assertEqual("my_network", outside_net_before.name)
self.assertStartsWith(
expected_start="my_network_1/dense_1/",
expected_start="my_network/dense/",
actual=outside_net_before.trainable_weights[0].name)
self.assertEqual("my_network_2", outside_net_after.name)
self.assertEqual("my_network_1", outside_net_after.name)
self.assertStartsWith(
expected_start="my_network_2/dense_1/",
expected_start="my_network_1/dense/",
actual=outside_net_after.trainable_weights[0].name)
@test_util.run_in_graph_and_eager_modes()
@ -490,21 +576,21 @@ class NetworkTest(test.TestCase):
net = MyNetwork()
net(constant_op.constant([[2.0]]))
self.evaluate(net.variables[0].assign([[42.]]))
self.assertEqual(net.name, "scope1/scope2/my_network_1")
self.assertEqual(net.name, "scope1/scope2/my_network")
self.assertStartsWith(
expected_start="scope1/scope2/my_network_1/dense_1/",
expected_start="scope1/scope2/my_network/dense/",
actual=net.trainable_weights[0].name)
save_path = net.save(self.get_temp_dir())
self.assertIn("scope1_scope2_my_network_1", save_path)
save_path = network.save_network_checkpoint(net, self.get_temp_dir())
self.assertIn("scope1_scope2_my_network", save_path)
restore_net = MyNetwork()
# Delayed restoration
restore_net.restore(save_path)
network.restore_network_checkpoint(restore_net, save_path)
restore_net(constant_op.constant([[1.0]]))
self.assertAllEqual([[42.]],
self.evaluate(restore_net.variables[0]))
self.evaluate(restore_net.variables[0].assign([[-1.]]))
# Immediate restoration
restore_net.restore(save_path)
network.restore_network_checkpoint(restore_net, save_path)
self.assertAllEqual([[42.]],
self.evaluate(restore_net.variables[0]))
@ -523,7 +609,7 @@ class NetworkTest(test.TestCase):
one = constant_op.constant([[1.]])
net = ParentNetwork()
net(one)
self.assertStartsWith(expected_start="parent_network_1/explicit_name/",
self.assertStartsWith(expected_start="parent_network/explicit_name/",
actual=net.trainable_weights[0].name)
self.assertEqual("explicit_name", net.first.name)
@ -578,15 +664,15 @@ class NetworkTest(test.TestCase):
# locally so that previous Layer construction does not interfere with
# variable naming (e.g. add a Layer construction before the Network,
# suddenly your previously saved checkpoint is incompatible).
self.assertEqual("dense_1", net1.l1.name)
self.assertEqual("dense_1", net2.l1.name)
self.assertEqual("dense", net1.l1.name)
self.assertEqual("dense", net2.l1.name)
self.evaluate(net1.trainable_weights[0].assign([[1.]]))
self.evaluate(net2.trainable_weights[0].assign([[2.]]))
self.assertEqual(2., self.evaluate(net2.trainable_weights[0]))
self.assertEqual(1., self.evaluate(net1.trainable_weights[0]))
self.assertStartsWith(expected_start="my_network_1/dense_1/",
self.assertStartsWith(expected_start="my_network/dense/",
actual=net1.trainable_weights[0].name)
self.assertStartsWith(expected_start="my_network_2/dense_1/",
self.assertStartsWith(expected_start="my_network_1/dense/",
actual=net2.trainable_weights[0].name)
@test_util.run_in_graph_and_eager_modes()
@ -607,31 +693,31 @@ class NetworkTest(test.TestCase):
one = constant_op.constant([[1.]])
net = ParentNetwork()
net(one)
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
self.assertStartsWith(expected_start="parent_network/my_network/dense",
actual=net.trainable_weights[0].name)
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
self.assertStartsWith(expected_start="parent_network/my_network/dense",
actual=net.first.trainable_weights[0].name)
self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense",
self.assertStartsWith(expected_start="parent_network/my_network_1/dense",
actual=net.trainable_weights[1].name)
self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense",
self.assertStartsWith(expected_start="parent_network/my_network_1/dense",
actual=net.second.trainable_weights[0].name)
self.assertEqual("parent_network_1", net.name)
self.assertEqual("my_network_1", net.first.name)
self.assertEqual("my_network_2", net.second.name)
self.assertEqual("parent_network", net.name)
self.assertEqual("my_network", net.first.name)
self.assertEqual("my_network_1", net.second.name)
net2 = ParentNetwork()
net2(one)
self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense",
self.assertStartsWith(expected_start="parent_network_1/my_network/dense",
actual=net2.trainable_weights[0].name)
self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense",
self.assertStartsWith(expected_start="parent_network_1/my_network/dense",
actual=net2.first.trainable_weights[0].name)
self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense",
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
actual=net2.trainable_weights[1].name)
self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense",
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
actual=net2.second.trainable_weights[0].name)
self.assertEqual("parent_network_2", net2.name)
self.assertEqual("my_network_1", net2.first.name)
self.assertEqual("my_network_2", net2.second.name)
self.assertEqual("parent_network_1", net2.name)
self.assertEqual("my_network", net2.first.name)
self.assertEqual("my_network_1", net2.second.name)
@test_util.run_in_graph_and_eager_modes()
def testNestableExplicit(self):
@ -692,26 +778,26 @@ class NetworkTest(test.TestCase):
one = constant_op.constant([[1.]])
net = MixedLayerNetwork()
net(one)
self.assertEqual("dense_1", net.first.name)
self.assertEqual("dense_2", net.second.name)
self.assertEqual("dense_3", net.third.name)
self.assertEqual("dense_4", net.fourth.name)
self.assertEqual("dense_5", net.fifth.name)
self.assertEqual("dense", net.first.name)
self.assertEqual("dense_1", net.second.name)
self.assertEqual("dense_2", net.third.name)
self.assertEqual("dense_3", net.fourth.name)
self.assertEqual("dense_4", net.fifth.name)
# Note that this is _not_ the default naming behavior for Layers. Layers
# which are added to Networks follow Network variable naming conventions
# (i.e. variable names = network name unless variable sharing). Nested
# Layers revert to Layer behavior.
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/",
self.assertStartsWith(expected_start="mixed_layer_network/dense/",
actual=net.trainable_weights[0].name)
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/",
self.assertStartsWith(expected_start="mixed_layer_network/dense_1/",
actual=net.trainable_weights[1].name)
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/",
self.assertStartsWith(expected_start="mixed_layer_network/dense_2/",
actual=net.trainable_weights[2].name)
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/",
self.assertStartsWith(expected_start="mixed_layer_network/dense_3/",
actual=net.trainable_weights[3].name)
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/",
self.assertStartsWith(expected_start="mixed_layer_network/dense_4/",
actual=net.trainable_weights[4].name)
self.assertEqual("mixed_layer_network_1", net.name)
self.assertEqual("mixed_layer_network", net.name)
@test_util.run_in_graph_and_eager_modes()
def testNestableExplicitCollisions(self):
@ -764,24 +850,24 @@ class NetworkTest(test.TestCase):
net = ParentNetwork()
net(one)
self.assertStartsWith(
expected_start="parent_network_1/first_unique_child_name/dense_1/",
expected_start="parent_network/first_unique_child_name/dense/",
actual=net.trainable_weights[0].name)
self.assertStartsWith(
expected_start="parent_network_1/second_unique_child_name/dense_1/",
expected_start="parent_network/second_unique_child_name/dense/",
actual=net.trainable_weights[1].name)
self.assertEqual("parent_network_1", net.name)
self.assertEqual("parent_network", net.name)
self.assertEqual("first_unique_child_name", net.first.name)
self.assertEqual("second_unique_child_name", net.second.name)
net2 = ParentNetwork()
net2(one)
self.assertStartsWith(
expected_start="parent_network_2/first_unique_child_name/dense",
expected_start="parent_network_1/first_unique_child_name/dense",
actual=net2.trainable_weights[0].name)
self.assertStartsWith(
expected_start="parent_network_2/second_unique_child_name/dense",
expected_start="parent_network_1/second_unique_child_name/dense",
actual=net2.trainable_weights[1].name)
self.assertEqual("parent_network_2", net2.name)
self.assertEqual("parent_network_1", net2.name)
self.assertEqual("first_unique_child_name", net2.first.name)
self.assertEqual("second_unique_child_name", net2.second.name)
@ -839,15 +925,15 @@ class NetworkTest(test.TestCase):
net2(one)
self.assertStartsWith(
expected_start="first_parent_network_1/my_network_1/dense_1/",
expected_start="first_parent_network/my_network/dense/",
actual=net2.trainable_weights[0].name)
self.assertStartsWith(
expected_start="second_parent_network_1/my_network_1/dense_1/",
expected_start="second_parent_network/my_network/dense/",
actual=net2.trainable_weights[1].name)
self.assertEqual("second_parent_network_1", net2.name)
self.assertEqual("second_parent_network", net2.name)
self.assertTrue(net2.first is net.first)
self.assertEqual("my_network_1", net2.first.name)
self.assertEqual("my_network_1", net2.second.name)
self.assertEqual("my_network", net2.first.name)
self.assertEqual("my_network", net2.second.name)
# No name collision; the owned Network is added first and has a different
# name than the shared Network.
@ -865,15 +951,15 @@ class NetworkTest(test.TestCase):
net3(one)
self.assertStartsWith(
expected_start="third_parent_network_1/my_network_1/dense",
expected_start="third_parent_network/my_network/dense",
actual=net3.trainable_weights[0].name)
self.assertStartsWith(
expected_start="first_parent_network_1/my_network_2/dense",
expected_start="first_parent_network/my_network_1/dense",
actual=net3.trainable_weights[1].name)
self.assertEqual("third_parent_network_1", net3.name)
self.assertEqual("third_parent_network", net3.name)
self.assertTrue(net3.second is net.second)
self.assertEqual("my_network_1", net3.first.name)
self.assertEqual("my_network_2", net3.second.name)
self.assertEqual("my_network", net3.first.name)
self.assertEqual("my_network_1", net3.second.name)
# "Unavoidable" same-name Layer. The owned name is added first (fixed), then
# a shared Network is added with the same name.
@ -891,15 +977,15 @@ class NetworkTest(test.TestCase):
net4(one)
self.assertStartsWith(
expected_start="fourth_parent_network_1/my_network_1/dense_1/",
expected_start="fourth_parent_network/my_network/dense/",
actual=net4.trainable_weights[0].name)
self.assertStartsWith(
expected_start="first_parent_network_1/my_network_1/dense_1/",
expected_start="first_parent_network/my_network/dense/",
actual=net4.trainable_weights[1].name)
self.assertEqual("fourth_parent_network_1", net4.name)
self.assertEqual("fourth_parent_network", net4.name)
self.assertTrue(net4.second is net.first)
self.assertEqual("my_network_1", net4.first.name)
self.assertEqual("my_network_1", net4.second.name)
self.assertEqual("my_network", net4.first.name)
self.assertEqual("my_network", net4.second.name)
@test_util.run_in_graph_and_eager_modes()
def testRecursiveLayerRenaming(self):
@ -930,28 +1016,28 @@ class NetworkTest(test.TestCase):
net(one)
self.assertStartsWith(
expected_start=("parent_network_1/network_with_layer_children_1/"
"dense_1/"),
expected_start=("parent_network/network_with_layer_children/"
"dense/"),
actual=net.trainable_weights[0].name)
self.assertStartsWith(
expected_start=("parent_network_1/network_with_layer_children_1/"
"dense_2/"),
expected_start=("parent_network/network_with_layer_children/"
"dense_1/"),
actual=net.trainable_weights[1].name)
self.assertStartsWith(
expected_start=("parent_network_1/network_with_layer_children_2/"
"dense_1/"),
expected_start=("parent_network/network_with_layer_children_1/"
"dense/"),
actual=net.trainable_weights[2].name)
self.assertStartsWith(
expected_start=("parent_network_1/network_with_layer_children_2/"
"dense_2/"),
expected_start=("parent_network/network_with_layer_children_1/"
"dense_1/"),
actual=net.trainable_weights[3].name)
self.assertEqual("parent_network_1", net.name)
self.assertEqual("network_with_layer_children_1", net.first.name)
self.assertEqual("network_with_layer_children_2", net.second.name)
self.assertEqual("dense_1", net.first.first.name)
self.assertEqual("dense_2", net.first.second.name)
self.assertEqual("dense_1", net.second.first.name)
self.assertEqual("dense_2", net.second.second.name)
self.assertEqual("parent_network", net.name)
self.assertEqual("network_with_layer_children", net.first.name)
self.assertEqual("network_with_layer_children_1", net.second.name)
self.assertEqual("dense", net.first.first.name)
self.assertEqual("dense_1", net.first.second.name)
self.assertEqual("dense", net.second.first.name)
self.assertEqual("dense_1", net.second.second.name)
@test_util.run_in_graph_and_eager_modes()
def testCallInDifferentOrderThanConstruct(self):
@ -985,23 +1071,23 @@ class NetworkTest(test.TestCase):
net1(one)
self.assertStartsWith(
expected_start="first_network_1/my_network_1/dense_1/",
expected_start="first_network/my_network/dense/",
actual=net1.trainable_weights[0].name)
self.assertStartsWith(
expected_start="first_network_1/my_network_2/dense_1/",
expected_start="first_network/my_network_1/dense/",
actual=net1.trainable_weights[1].name)
self.assertStartsWith(
expected_start="first_network_1/my_network_1/dense_1/",
expected_start="first_network/my_network/dense/",
actual=net2.trainable_weights[0].name)
self.assertStartsWith(
expected_start="second_network_1/my_network_1/dense_1/",
expected_start="second_network/my_network/dense/",
actual=net2.trainable_weights[1].name)
self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0])
self.assertEqual("first_network_1", net1.name)
self.assertEqual("my_network_1", net1.first.name)
self.assertEqual("my_network_2", net1.second.name)
self.assertEqual("first_network", net1.name)
self.assertEqual("my_network", net1.first.name)
self.assertEqual("my_network_1", net1.second.name)
self.assertTrue(net2.first is net1.first)
self.assertEqual("my_network_1", net2.second.name)
self.assertEqual("my_network", net2.second.name)
@test_util.run_in_graph_and_eager_modes()
def testLayerCallInDifferentOrderThanConstruct(self):
@ -1038,23 +1124,23 @@ class NetworkTest(test.TestCase):
net1(one)
self.assertStartsWith(
expected_start="first_network_1/dense_1/",
expected_start="first_network/dense/",
actual=net1.trainable_weights[0].name)
self.assertStartsWith(
expected_start="first_network_1/dense_2/",
expected_start="first_network/dense_1/",
actual=net1.trainable_weights[1].name)
self.assertStartsWith(
expected_start="first_network_1/dense_1/",
expected_start="first_network/dense/",
actual=net2.trainable_weights[0].name)
self.assertStartsWith(
expected_start="second_network_1/dense_1/",
expected_start="second_network/dense/",
actual=net2.trainable_weights[1].name)
self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0])
self.assertEqual("first_network_1", net1.name)
self.assertEqual("dense_1", net1.first.name)
self.assertEqual("dense_2", net1.second.name)
self.assertEqual("first_network", net1.name)
self.assertEqual("dense", net1.first.name)
self.assertEqual("dense_1", net1.second.name)
self.assertTrue(net2.first is net1.first)
self.assertEqual("dense_1", net2.second.name)
self.assertEqual("dense", net2.second.name)
@test_util.run_in_graph_and_eager_modes()
def testLayerAlreadyBuilt(self):
@ -1083,13 +1169,13 @@ class NetworkTest(test.TestCase):
# do not match their layer names.
actual=net.trainable_weights[0].name)
self.assertStartsWith(
expected_start="first_network_1/dense_1/",
expected_start="first_network/dense/",
actual=net.trainable_weights[1].name)
self.assertTrue(
net.trainable_weights[0] is shared_layer.trainable_weights[0])
self.assertEqual("first_network_1", net.name)
self.assertEqual("first_network", net.name)
self.assertEqual("dense_3", net.first.name)
self.assertEqual("dense_1", net.second.name)
self.assertEqual("dense", net.second.name)
class SequentialTest(test.TestCase):


@ -30,9 +30,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@value_and_gradients_function
@@GradientTape
@@enable_tracing
@@flush_trace
@@run
@@enable_eager_execution
@ -46,13 +43,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@seterr
@@Iterator
@@Network
@@Saver
@@restore_variables_on_create
@@Variable
@@get_optimizer_variables
@@EagerVariableStore
@@Network
@@save_network_checkpoint
@@restore_network_checkpoint
@@in_eager_mode
@@in_graph_mode
@ -74,6 +74,8 @@ from __future__ import print_function
from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.eager.python.datasets import Iterator
from tensorflow.contrib.eager.python.network import Network
from tensorflow.contrib.eager.python.network import save_network_checkpoint
from tensorflow.contrib.eager.python.network import restore_network_checkpoint
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver
@ -86,7 +88,6 @@ from tensorflow.python.eager.context import in_eager_mode
from tensorflow.python.eager.context import in_graph_mode
from tensorflow.python.eager.context import list_devices
from tensorflow.python.eager.context import num_gpus
from tensorflow.python.eager.core import enable_tracing
from tensorflow.python.eager.custom_gradient import custom_gradient
from tensorflow.python.eager.execution_callbacks import add_execution_callback
from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
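With Network.save and the old Saver-based flow removed from this namespace, network checkpointing now goes through the two module-level helpers imported above. A hedged usage sketch (paths are placeholders and the signatures are assumed from these imports, not from documentation):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

net = MyNetwork()        # hypothetical tfe.Network subclass, built elsewhere
net(tf.zeros([1, 2]))    # call once so variables exist before saving
prefix = tfe.save_network_checkpoint(net, "/tmp/net_ckpt/net")
tfe.restore_network_checkpoint(net, prefix)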


@ -208,6 +208,7 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:summary",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:metric_keys",


@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
@ -342,14 +343,19 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
predictions = {}
metrics = {}
losses = []
for head, spec in zip(self._heads, all_estimator_spec):
losses.append(spec.loss)
head_name = head.name
# Metric keys already contain head.name.
metrics.update(spec.eval_metric_ops or {})
for k, v in six.iteritems(spec.predictions):
predictions[(head_name, k)] = v
loss = _merge_losses(losses, self._head_weights)
with ops.name_scope('merge_eval'):
for head, spec in zip(self._heads, all_estimator_spec):
losses.append(spec.loss)
head_name = head.name
# Loss metric is not added by default.
loss_name = head_lib._summary_key( # pylint:disable=protected-access
head_name, metric_keys.MetricKeys.LOSS)
metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name)
# Metric keys already contain head.name.
metrics.update(spec.eval_metric_ops or {})
for k, v in six.iteritems(spec.predictions):
predictions[(head_name, k)] = v
loss = _merge_losses(losses, self._head_weights)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,


@ -297,6 +297,8 @@ class MultiHeadTest(test.TestCase):
keys = metric_keys.MetricKeys
expected_metrics = {
keys.LOSS + '/head1': expected_loss_head1,
keys.LOSS + '/head2': expected_loss_head2,
# Average loss over examples.
keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
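The per-head loss now appears in the eval metrics next to the existing mean loss. Illustrative only: how those keys are composed, assuming _summary_key simply joins the metric key and the head name with '/':

from tensorflow.python.estimator.canned import metric_keys

keys = metric_keys.MetricKeys
loss_keys = [keys.LOSS + '/' + name for name in ('head1', 'head2')]
# -> ['loss/head1', 'loss/head2']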


@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.factorization.python.ops import factorization_ops
from tensorflow.contrib.framework.python.ops import variables as framework_variables
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.framework import dtypes
@ -32,175 +31,64 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
class _SweepHook(session_run_hook.SessionRunHook):
"""Keeps track of row/col sweeps, and runs prep ops before each sweep."""
def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols,
input_row_indices, input_col_indices, row_prep_ops,
col_prep_ops, init_op, completed_sweeps_var):
def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op,
row_prep_ops, col_prep_ops, row_train_op, col_train_op,
switch_op):
"""Initializes SweepHook.
Args:
is_row_sweep_var: A Boolean tf.Variable, determines whether we are
currently doing a row or column sweep. It is updated by the hook.
train_ops: A list of ops. The ops created by this hook will have
control dependencies on `train_ops`.
num_rows: int, the total number of rows to be processed.
num_cols: int, the total number of columns to be processed.
input_row_indices: A Tensor of type int64. The indices of the input rows
that are processed during the current sweep. All elements of
`input_row_indices` must be in [0, num_rows).
input_col_indices: A Tensor of type int64. The indices of the input
columns that are processed during the current sweep. All elements of
`input_col_indices` must be in [0, num_cols).
row_prep_ops: list of ops, to be run before the beginning of each row
sweep, in the given order.
col_prep_ops: list of ops, to be run before the beginning of each column
sweep, in the given order.
is_sweep_done_var: A Boolean tf.Variable, determines whether we are
starting a new sweep (this is used to determine when to run the prep ops
below).
init_op: op to be run once before training. This is typically a local
initialization op (such as cache initialization).
completed_sweeps_var: An integer tf.Variable, indicates the number of
completed sweeps. It is updated by the hook.
row_prep_ops: A list of TensorFlow ops, to be run before the beginning of
each row sweep (and during initialization), in the given order.
col_prep_ops: A list of TensorFlow ops, to be run before the beginning of
each column sweep (and during initialization), in the given order.
row_train_op: A TensorFlow op to be run during row sweeps.
col_train_op: A TensorFlow op to be run during column sweeps.
switch_op: A TensorFlow op to be run before each sweep.
"""
self._num_rows = num_rows
self._num_cols = num_cols
self._is_row_sweep_var = is_row_sweep_var
self._is_sweep_done_var = is_sweep_done_var
self._init_op = init_op
self._row_prep_ops = row_prep_ops
self._col_prep_ops = col_prep_ops
self._init_op = init_op
self._is_row_sweep_var = is_row_sweep_var
self._completed_sweeps_var = completed_sweeps_var
# Boolean variable that determines whether the init_ops have been run.
self._row_train_op = row_train_op
self._col_train_op = col_train_op
self._switch_op = switch_op
# Boolean variable that determines whether the init_op has been run.
self._is_initialized = False
# Ops to run jointly with train_ops, responsible for updating
# `is_row_sweep_var` and incrementing the `global_step` and
# `completed_sweeps` counters.
self._update_op, self._is_sweep_done_var, self._switch_op = (
self._create_hook_ops(input_row_indices, input_col_indices, train_ops))
def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops):
"""Creates ops to update is_row_sweep_var, global_step and completed_sweeps.
Creates two boolean tensors `processed_rows` and `processed_cols`, which
keep track of which rows/cols have been processed during the current sweep.
Returns ops that should be run after each row / col update.
- When `self._is_row_sweep_var` is True, it sets
processed_rows[input_row_indices] to True.
- When `self._is_row_sweep_var` is False, it sets
processed_cols[input_col_indices] to True.
Args:
input_row_indices: A Tensor. The indices of the input rows that are
processed during the current sweep.
input_col_indices: A Tensor. The indices of the input columns that
are processed during the current sweep.
train_ops: A list of ops. The ops created by this function have control
dependencies on `train_ops`.
Returns:
A tuple consisting of:
update_op: An op to be run jointly with training. It updates the state
and increments counters (global step and completed sweeps).
is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is
done, i.e. all rows (during a row sweep) or all columns (during a
column sweep) have been processed.
switch_op: An op to be run in `self.before_run` when the sweep is done.
"""
processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
with ops.colocate_with(processed_rows_init):
processed_rows = variable_scope.variable(
processed_rows_init,
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
trainable=False,
name="sweep_hook_processed_rows")
processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False)
with ops.colocate_with(processed_cols_init):
processed_cols = variable_scope.variable(
processed_cols_init,
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
trainable=False,
name="sweep_hook_processed_cols")
switch_ops = control_flow_ops.group(
state_ops.assign(
self._is_row_sweep_var,
math_ops.logical_not(self._is_row_sweep_var)),
state_ops.assign(processed_rows, processed_rows_init),
state_ops.assign(processed_cols, processed_cols_init))
is_sweep_done_var = variable_scope.variable(
False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
trainable=False,
name="is_sweep_done")
# After running the `train_ops`, updates `processed_rows` or
# `processed_cols` tensors, depending on whether this is a row or col sweep.
with ops.control_dependencies(train_ops):
with ops.colocate_with(processed_rows):
update_processed_rows = state_ops.scatter_update(
processed_rows,
input_row_indices,
math_ops.logical_and(
self._is_row_sweep_var,
array_ops.ones_like(input_row_indices, dtype=dtypes.bool)))
with ops.colocate_with(processed_cols):
update_processed_cols = state_ops.scatter_update(
processed_cols,
input_col_indices,
math_ops.logical_and(
math_ops.logical_not(self._is_row_sweep_var),
array_ops.ones_like(input_col_indices, dtype=dtypes.bool)))
update_processed_op = control_flow_ops.group(
update_processed_rows, update_processed_cols)
with ops.control_dependencies([update_processed_op]):
is_sweep_done = math_ops.logical_or(
math_ops.reduce_all(processed_rows),
math_ops.reduce_all(processed_cols))
# Increments global step.
global_step = framework_variables.get_global_step()
if global_step is not None:
global_step_incr_op = state_ops.assign_add(
global_step, 1, name="global_step_incr").op
else:
global_step_incr_op = control_flow_ops.no_op()
# Increments completed sweeps.
completed_sweeps_incr_op = state_ops.assign_add(
self._completed_sweeps_var,
math_ops.cast(is_sweep_done, dtypes.int32),
use_locking=True).op
update_ops = control_flow_ops.group(
global_step_incr_op,
completed_sweeps_incr_op,
state_ops.assign(is_sweep_done_var, is_sweep_done))
return update_ops, is_sweep_done_var, switch_ops
def before_run(self, run_context):
"""Runs the appropriate prep ops, and requests running update ops."""
# Runs the appropriate init ops and prep ops.
sess = run_context.session
is_sweep_done = sess.run(self._is_sweep_done_var)
if not self._is_initialized:
logging.info("SweepHook running cache init op.")
logging.info("SweepHook running init op.")
sess.run(self._init_op)
if is_sweep_done:
sess.run(self._switch_op)
is_row_sweep = sess.run(self._is_row_sweep_var)
if is_sweep_done or not self._is_initialized:
logging.info("SweepHook running sweep prep ops.")
row_sweep = sess.run(self._is_row_sweep_var)
prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops
logging.info("SweepHook running prep ops for the {} sweep.".format(
"row" if is_row_sweep else "col"))
prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops
for prep_op in prep_ops:
sess.run(prep_op)
self._is_initialized = True
# Requests running `self._update_op` jointly with the training op.
logging.info("Next fit step starting.")
return session_run_hook.SessionRunArgs(fetches=[self._update_op])
def after_run(self, run_context, run_values):
logging.info("Fit step done.")
return session_run_hook.SessionRunArgs(
fetches=[self._row_train_op if is_row_sweep else self._col_train_op])
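The reworked hook reduces to eight graph handles supplied at construction. A toy wiring with no-op stand-ins for the real prep and train ops (_SweepHook is internal, so this mirrors the constructor above rather than a public API):

import tensorflow as tf

is_row_sweep_var = tf.Variable(True, trainable=False)
is_sweep_done_var = tf.Variable(False, trainable=False)
switch_op = tf.group(
    tf.assign(is_sweep_done_var, False),
    tf.assign(is_row_sweep_var, tf.logical_not(is_row_sweep_var)))

hook = _SweepHook(
    is_row_sweep_var, is_sweep_done_var,
    init_op=tf.no_op(),
    row_prep_ops=[tf.no_op()], col_prep_ops=[tf.no_op()],
    row_train_op=tf.no_op(), col_train_op=tf.no_op(),
    switch_op=switch_op)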
class _StopAtSweepHook(session_run_hook.SessionRunHook):
@ -246,6 +134,9 @@ def _wals_factorization_model_function(features, labels, mode, params):
Returns:
A ModelFnOps object.
Raises:
ValueError: If `mode` is not recognized.
"""
assert labels is None
use_factors_weights_cache = (params["use_factors_weights_cache_for_training"]
@ -269,86 +160,156 @@ def _wals_factorization_model_function(features, labels, mode, params):
use_gramian_cache=use_gramian_cache)
# Get input rows and cols. We either update rows or columns depending on
# the value of row_sweep, which is maintained using a session hook
# the value of row_sweep, which is maintained using a session hook.
input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
input_cols = features[WALSMatrixFactorization.INPUT_COLS]
input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0])
input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0])
# Train ops, controlled using the SweepHook
# We need to run the following ops:
# Before a row sweep:
# row_update_prep_gramian_op
# initialize_row_update_op
# During a row sweep:
# update_row_factors_op
# Before a col sweep:
# col_update_prep_gramian_op
# initialize_col_update_op
# During a col sweep:
# update_col_factors_op
# TRAIN mode:
if mode == model_fn.ModeKeys.TRAIN:
# Training consists of the following ops (controlled using a SweepHook).
# Before a row sweep:
# row_update_prep_gramian_op
# initialize_row_update_op
# During a row sweep:
# update_row_factors_op
# Before a col sweep:
# col_update_prep_gramian_op
# initialize_col_update_op
# During a col sweep:
# update_col_factors_op
is_row_sweep_var = variable_scope.variable(
True,
trainable=False,
name="is_row_sweep",
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
completed_sweeps_var = variable_scope.variable(
0,
trainable=False,
name=WALSMatrixFactorization.COMPLETED_SWEEPS,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
is_row_sweep_var = variable_scope.variable(
True,
trainable=False,
name="is_row_sweep",
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
is_sweep_done_var = variable_scope.variable(
False,
trainable=False,
name="is_sweep_done",
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
completed_sweeps_var = variable_scope.variable(
0,
trainable=False,
name=WALSMatrixFactorization.COMPLETED_SWEEPS,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
loss_var = variable_scope.variable(
0.,
trainable=False,
name=WALSMatrixFactorization.LOSS,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
# The root weighted squared error =
# \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
rwse_var = variable_scope.variable(
0.,
trainable=False,
name=WALSMatrixFactorization.RWSE,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
# The row sweep is determined by is_row_sweep_var (controlled by the
# sweep_hook) in TRAIN mode, and manually in EVAL mode.
is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var)
summary.scalar("loss", loss_var)
summary.scalar("root_weighted_squared_error", rwse_var)
summary.scalar("completed_sweeps", completed_sweeps_var)
def update_row_factors():
return model.update_row_factors(sp_input=input_rows, transpose_input=False)
# Increments global step.
global_step = training_util.get_global_step()
if global_step:
global_step_incr_op = state_ops.assign_add(
global_step, 1, name="global_step_incr").op
else:
global_step_incr_op = control_flow_ops.no_op()
def update_col_factors():
return model.update_col_factors(sp_input=input_cols, transpose_input=True)
def create_axis_ops(sp_input, num_items, update_fn, axis_name):
"""Creates book-keeping and training ops for a given axis.
(_, train_op,
unregularized_loss, regularization, sum_weights) = control_flow_ops.cond(
is_row_sweep, update_row_factors, update_col_factors)
loss = unregularized_loss + regularization
root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights)
Args:
sp_input: A SparseTensor corresponding to the row or column batch.
num_items: An integer, the total number of items of this axis.
update_fn: A function that takes one argument (`sp_input`), and that
returns a tuple of
* new_factors: A float Tensor of the factor values after update.
* update_op: A TensorFlow op which updates the factors.
* loss: A float Tensor, the unregularized loss.
* reg_loss: A float Tensor, the regularization loss.
* sum_weights: A float Tensor, the sum of factor weights.
axis_name: A string that specifies the name of the axis.
row_prep_ops = [
model.row_update_prep_gramian_op, model.initialize_row_update_op
]
col_prep_ops = [
model.col_update_prep_gramian_op, model.initialize_col_update_op
]
init_ops = [model.worker_init]
Returns:
A tuple consisting of:
* reset_processed_items_op: A TensorFlow op, to be run before the
beginning of any sweep. It marks all items as not-processed.
* axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
"""
processed_items_init = array_ops.fill(dims=[num_items], value=False)
with ops.colocate_with(processed_items_init):
processed_items = variable_scope.variable(
processed_items_init,
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
trainable=False,
name="processed_" + axis_name)
reset_processed_items_op = state_ops.assign(
processed_items, processed_items_init,
name="reset_processed_" + axis_name)
_, update_op, loss, reg, sum_weights = update_fn(sp_input)
input_indices = sp_input.indices[:, 0]
with ops.control_dependencies([
update_op,
state_ops.assign(loss_var, loss + reg),
state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
with ops.colocate_with(processed_items):
update_processed_items = state_ops.scatter_update(
processed_items,
input_indices,
array_ops.ones_like(input_indices, dtype=dtypes.bool),
name="update_processed_{}_indices".format(axis_name))
with ops.control_dependencies([update_processed_items]):
is_sweep_done = math_ops.reduce_all(processed_items)
axis_train_op = control_flow_ops.group(
global_step_incr_op,
state_ops.assign(is_sweep_done_var, is_sweep_done),
state_ops.assign_add(
completed_sweeps_var,
math_ops.cast(is_sweep_done, dtypes.int32)),
name="{}_sweep_train_op".format(axis_name))
return reset_processed_items_op, axis_train_op
sweep_hook = _SweepHook(
is_row_sweep_var,
[train_op, loss],
params["num_rows"],
params["num_cols"],
input_row_indices,
input_col_indices,
row_prep_ops,
col_prep_ops,
init_ops,
completed_sweeps_var)
training_hooks = [sweep_hook]
if max_sweeps is not None:
training_hooks.append(_StopAtSweepHook(max_sweeps))
reset_processed_rows_op, row_train_op = create_axis_ops(
input_rows,
params["num_rows"],
lambda x: model.update_row_factors(sp_input=x, transpose_input=False),
"rows")
reset_processed_cols_op, col_train_op = create_axis_ops(
input_cols,
params["num_cols"],
lambda x: model.update_col_factors(sp_input=x, transpose_input=True),
"cols")
switch_op = control_flow_ops.group(
state_ops.assign(
is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)),
reset_processed_rows_op,
reset_processed_cols_op,
name="sweep_switch_op")
row_prep_ops = [
model.row_update_prep_gramian_op, model.initialize_row_update_op]
col_prep_ops = [
model.col_update_prep_gramian_op, model.initialize_col_update_op]
init_op = model.worker_init
sweep_hook = _SweepHook(
is_row_sweep_var, is_sweep_done_var, init_op,
row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
training_hooks = [sweep_hook]
if max_sweeps is not None:
training_hooks.append(_StopAtSweepHook(max_sweeps))
# The root weighted squared error =
# \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
summary.scalar("loss", loss) # the estimated total training loss
summary.scalar("root_weighted_squared_error", root_weighted_squared_error)
summary.scalar("completed_sweeps", completed_sweeps_var)
return model_fn.ModelFnOps(
mode=model_fn.ModeKeys.TRAIN,
predictions={},
loss=loss_var,
eval_metric_ops={},
train_op=control_flow_ops.no_op(),
training_hooks=training_hooks)
# Prediction ops (only return predictions in INFER mode)
predictions = {}
if mode == model_fn.ModeKeys.INFER:
project_row = features[WALSMatrixFactorization.PROJECT_ROW]
# INFER mode
elif mode == model_fn.ModeKeys.INFER:
projection_weights = features.get(
WALSMatrixFactorization.PROJECTION_WEIGHTS)
@ -364,17 +325,45 @@ def _wals_factorization_model_function(features, labels, mode, params):
projection_weights=projection_weights,
transpose_input=True)
predictions[WALSMatrixFactorization.PROJECTION_RESULT] = (
control_flow_ops.cond(project_row, get_row_projection,
get_col_projection))
predictions = {
WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond(
features[WALSMatrixFactorization.PROJECT_ROW],
get_row_projection,
get_col_projection)
}
return model_fn.ModelFnOps(
mode=mode,
predictions=predictions,
loss=loss,
eval_metric_ops={},
train_op=train_op,
training_hooks=training_hooks)
return model_fn.ModelFnOps(
mode=model_fn.ModeKeys.INFER,
predictions=predictions,
loss=None,
eval_metric_ops={},
train_op=control_flow_ops.no_op(),
training_hooks=[])
# EVAL mode
elif mode == model_fn.ModeKeys.EVAL:
def get_row_loss():
_, _, loss, reg, _ = model.update_row_factors(
sp_input=input_rows, transpose_input=False)
return loss + reg
def get_col_loss():
_, _, loss, reg, _ = model.update_col_factors(
sp_input=input_cols, transpose_input=True)
return loss + reg
loss = control_flow_ops.cond(
features[WALSMatrixFactorization.PROJECT_ROW],
get_row_loss,
get_col_loss)
return model_fn.ModelFnOps(
mode=model_fn.ModeKeys.EVAL,
predictions={},
loss=loss,
eval_metric_ops={},
train_op=control_flow_ops.no_op(),
training_hooks=[])
else:
raise ValueError("mode=%s is not recognized." % str(mode))
class WALSMatrixFactorization(estimator.Estimator):
@ -452,6 +441,10 @@ class WALSMatrixFactorization(estimator.Estimator):
PROJECTION_RESULT = "projection"
# Name of the completed_sweeps variable
COMPLETED_SWEEPS = "completed_sweeps"
# Name of the loss variable
LOSS = "WALS_loss"
# Name of the Root Weighted Squared Error variable
RWSE = "WALS_RWSE"
def __init__(self,
num_rows,


@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
class SweepHookTest(test.TestCase):
def setUp(self):
self._num_rows = 5
self._num_cols = 7
self._train_op = control_flow_ops.no_op()
self._row_prep_done = variables.Variable(False)
self._col_prep_done = variables.Variable(False)
self._init_done = variables.Variable(False)
self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)]
self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)]
self._init_ops = [state_ops.assign(self._init_done, True)]
self._input_row_indices_ph = array_ops.placeholder(dtypes.int64)
self._input_col_indices_ph = array_ops.placeholder(dtypes.int64)
def test_sweeps(self):
def ind_feed(row_indices, col_indices):
return {
self._input_row_indices_ph: row_indices,
self._input_col_indices_ph: col_indices
}
is_row_sweep_var = variables.Variable(True)
is_sweep_done_var = variables.Variable(False)
init_done = variables.Variable(False)
row_prep_done = variables.Variable(False)
col_prep_done = variables.Variable(False)
row_train_done = variables.Variable(False)
col_train_done = variables.Variable(False)
init_op = state_ops.assign(init_done, True)
row_prep_op = state_ops.assign(row_prep_done, True)
col_prep_op = state_ops.assign(col_prep_done, True)
row_train_op = state_ops.assign(row_train_done, True)
col_train_op = state_ops.assign(col_train_done, True)
train_op = control_flow_ops.no_op()
switch_op = control_flow_ops.group(
state_ops.assign(is_sweep_done_var, False),
state_ops.assign(is_row_sweep_var,
math_ops.logical_not(is_row_sweep_var)))
mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
with self.test_session() as sess:
is_row_sweep_var = variables.Variable(True)
completed_sweeps_var = variables.Variable(0)
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
[self._train_op],
self._num_rows,
self._num_cols,
self._input_row_indices_ph,
self._input_col_indices_ph,
self._row_prep_ops,
self._col_prep_ops,
self._init_ops,
completed_sweeps_var)
is_sweep_done_var,
init_op,
[row_prep_op],
[col_prep_op],
row_train_op,
col_train_op,
switch_op)
mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
sess.run([variables.global_variables_initializer()])
# Init ops should run before the first run. Row sweep not completed.
mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
self.assertTrue(sess.run(self._init_done),
msg='init ops not run by the sweep_hook')
self.assertTrue(sess.run(self._row_prep_done),
msg='row_prep not run by the sweep_hook')
self.assertTrue(sess.run(is_row_sweep_var),
msg='Row sweep is not complete but is_row_sweep is '
'False.')
# Row sweep completed.
mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
self.assertTrue(sess.run(completed_sweeps_var) == 1,
msg='Completed sweeps should be equal to 1.')
self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is complete but is_sweep_done is False.')
# Col init ops should run. Col sweep not completed.
mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
self.assertTrue(sess.run(self._col_prep_done),
msg='col_prep not run by the sweep_hook')
self.assertFalse(sess.run(is_row_sweep_var),
msg='Col sweep is not complete but is_row_sweep is '
'True.')
self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is not complete but is_sweep_done is True.')
# Col sweep completed.
mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is complete but is_sweep_done is False.')
self.assertTrue(sess.run(completed_sweeps_var) == 2,
msg='Completed sweeps should be equal to 2.')
# Row sweep.
mon_sess.run(train_op)
self.assertTrue(sess.run(init_done),
msg='init op not run by the SweepHook')
self.assertTrue(sess.run(row_prep_done),
msg='row_prep_op not run by the SweepHook')
self.assertTrue(sess.run(row_train_done),
msg='row_train_op not run by the SweepHook')
self.assertTrue(
sess.run(is_row_sweep_var),
msg='Row sweep is not complete but is_row_sweep_var is False.')
# Col sweep.
mon_sess.run(mark_sweep_done)
mon_sess.run(train_op)
self.assertTrue(sess.run(col_prep_done),
msg='col_prep_op not run by the SweepHook')
self.assertTrue(sess.run(col_train_done),
msg='col_train_op not run by the SweepHook')
self.assertFalse(
sess.run(is_row_sweep_var),
msg='Col sweep is not complete but is_row_sweep_var is True.')
# Row sweep.
mon_sess.run(mark_sweep_done)
mon_sess.run(train_op)
self.assertTrue(
sess.run(is_row_sweep_var),
msg='Col sweep is complete but is_row_sweep_var is False.')
class StopAtSweepHookTest(test.TestCase):


@ -224,7 +224,8 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
`(x, y)` to a transformed *input* point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points.
the transform mapping input points to output points. Note that gradients
are not backpropagated into transformation parameters.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
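The added caveat matters when this op sits inside a training graph: the transform parameters are treated as constants by autodiff. A hedged usage sketch, where the 8-vector is the identity projective transform [a0, a1, a2, b0, b1, b2, c0, c1]:

import tensorflow as tf

images = tf.zeros([1, 4, 4, 1])
identity = [1., 0., 0., 0., 1., 0., 0., 0.]
output = tf.contrib.image.transform(images, identity, interpolation="BILINEAR")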


@ -68,6 +68,7 @@ py_test(
srcs = ["layer_collection_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/python:array_ops",
@ -75,6 +76,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variable_scope",


@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import fisher_blocks
from tensorflow.contrib.kfac.python.ops import fisher_factors
from tensorflow.contrib.kfac.python.ops import layer_collection
from tensorflow.python.framework import dtypes
@ -25,6 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@ -105,8 +107,10 @@ class LayerCollectionTest(test.TestCase):
array_ops.constant(4), [1, 1, 1, 1], 'SAME',
array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
lc.register_conv2d(
array_ops.constant(4), [1, 1, 1, 1], 'SAME',
array_ops.ones((1, 1, 1, 1)), array_ops.constant(3),
array_ops.constant(4), [1, 1, 1, 1],
'SAME',
array_ops.ones((1, 1, 1, 1)),
array_ops.constant(3),
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_generic(
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
@ -122,8 +126,8 @@ class LayerCollectionTest(test.TestCase):
random_seed.set_random_seed(200)
lc = layer_collection.LayerCollection()
key = array_ops.constant(1)
lc.register_fully_connected(key,
array_ops.constant(2), array_ops.constant(3))
lc.register_fully_connected(key, array_ops.constant(2),
array_ops.constant(3))
with self.assertRaises(ValueError):
lc.register_generic(key, 16)
@ -191,8 +195,8 @@ class LayerCollectionTest(test.TestCase):
lc.register_block((x, y), MockFisherBlock('foo'))
self.assertEqual(
set([MockFisherBlock('2'), MockFisherBlock('foo')]),
set(lc.get_blocks()))
set([MockFisherBlock('2'), MockFisherBlock('foo')]), set(
lc.get_blocks()))
def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
@ -464,6 +468,66 @@ class LayerCollectionTest(test.TestCase):
use_count_map = lc.get_use_count_map()
self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', shape=())
y = variable_scope.get_variable('y', shape=())
z = variable_scope.get_variable('z', shape=())
lc = layer_collection.LayerCollection()
lc.define_linked_parameters((x, y))
with self.assertRaises(ValueError):
lc.define_linked_parameters((x, z))
def testIdentifySubsetPreviouslyRegisteredTensor(self):
x = variable_scope.get_variable('x', shape=())
y = variable_scope.get_variable('y', shape=())
lc = layer_collection.LayerCollection()
lc.define_linked_parameters((x, y))
with self.assertRaises(ValueError):
lc.define_linked_parameters(x)
def testSpecifyApproximation(self):
w_0 = variable_scope.get_variable('w_0', [10, 10])
w_1 = variable_scope.get_variable('w_1', [10, 10])
b_0 = variable_scope.get_variable('b_0', [10])
b_1 = variable_scope.get_variable('b_1', [10])
x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
pre_bias_0 = math_ops.matmul(x_0, w_0)
pre_bias_1 = math_ops.matmul(x_1, w_1)
# Build the fully connected layers in the graph.
pre_bias_0 + b_0 # pylint: disable=pointless-statement
pre_bias_1 + b_1 # pylint: disable=pointless-statement
lc = layer_collection.LayerCollection()
lc.define_linked_parameters(
w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
lc.define_linked_parameters(
w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
lc.define_linked_parameters(
b_0, approximation=layer_collection.APPROX_FULL_NAME)
lc.define_linked_parameters(
b_1, approximation=layer_collection.APPROX_FULL_NAME)
lc.register_fully_connected(w_0, x_0, pre_bias_0)
lc.register_fully_connected(
w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
self.assertIsInstance(lc.fisher_blocks[w_0],
fisher_blocks.FullyConnectedDiagonalFB)
self.assertIsInstance(lc.fisher_blocks[w_1],
fisher_blocks.FullyConnectedKFACBasicFB)
lc.register_generic(b_0, batch_size=1)
lc.register_generic(
b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
if __name__ == '__main__':
test.main()
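A condensed sketch of the selection machinery this test exercises, matching the layer_collection module diffed below (the variable is illustrative):

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

w = tf.get_variable('w', [10, 10])
lc = layer_collection.LayerCollection()
# Attach a preferred approximation to the parameters themselves...
lc.define_linked_parameters(
    w, approximation=layer_collection.APPROX_DIAGONAL_NAME)
# ...or change the per-collection default used when nothing is linked.
lc.default_generic_approximation = layer_collection.APPROX_DIAGONAL_NAME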


@ -38,12 +38,26 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
# Names for various approximations that can be requested for Fisher blocks.
APPROX_KRONECKER_NAME = "kron"
APPROX_DIAGONAL_NAME = "diagonal"
APPROX_FULL_NAME = "full"
_GENERIC_APPROX_TO_BLOCK_TYPES = {
APPROX_FULL_NAME: fb.FullFB,
APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
}
_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = {
APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
}
_CONV2D_APPROX_TO_BLOCK_TYPES = {
APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
}
# Possible value for 'reuse' keyword argument. Sets 'reuse' to
# tf.get_variable_scope().reuse.
VARIABLE_SCOPE = "VARIABLE_SCOPE"
@ -51,6 +65,14 @@ VARIABLE_SCOPE = "VARIABLE_SCOPE"
# TODO(jamesmartens): need to add find_canonical_output back into this somewhere
def ensure_sequence(obj):
"""If `obj` isn't a tuple or list, return a tuple containing `obj`."""
if isinstance(obj, (tuple, list)):
return obj
else:
return (obj,)
class LayerParametersDict(OrderedDict):
"""An OrderedDict where keys are Tensors or tuples of Tensors.
@ -110,9 +132,14 @@ class LayerCollection(object):
def __init__(self, graph=None, name="LayerCollection"):
self.fisher_blocks = LayerParametersDict()
self.fisher_factors = OrderedDict()
self._linked_parameters = dict(
) # dict mapping sets of variables to optionally specified approximations.
self._graph = graph or ops.get_default_graph()
self._loss_dict = {} # {str: LossFunction}
self._subgraph = None
self._default_generic_approximation = APPROX_FULL_NAME
self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
with variable_scope.variable_scope(None, default_name=name) as scope:
self._var_scope = scope.name
@ -122,6 +149,70 @@ class LayerCollection(object):
"""LossFunctions registered with this LayerCollection."""
return list(self._loss_dict.values())
def is_variable_registered(self, variable):
"""Checks whether the variable has already been registered.
Args:
variable: A single variable or tensor.
Returns:
True if the variable has been registered either by itself or as part of a
tuple.
"""
return any([
variable in key if isinstance(key, (tuple, list)) else variable == key
for key in self.fisher_blocks.keys()
])
@property
def linked_parameters(self):
"""Groups of parameters with an optionally specified approximation.
Linked parameters can be added using `define_linked_parameters`.
If an approximation is specified, then this approximation will be used
when registering a layer with exactly these parameters, unless an
approximation is specified when calling the registration function.
Returns:
A `dict` mapping tuples of parameters to an optional string.
"""
return self._linked_parameters
@property
def default_generic_approximation(self):
return self._default_generic_approximation
@default_generic_approximation.setter
def default_generic_approximation(self, value):
if value not in _GENERIC_APPROX_TO_BLOCK_TYPES:
raise ValueError(
"{} is not a valid approximation for generic variables.".format(
value))
self._default_generic_approximation = value
@property
def default_fully_connected_approximation(self):
return self._default_fully_connected_approximation
@default_fully_connected_approximation.setter
def default_fully_connected_approximation(self, value):
if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
raise ValueError(
"{} is not a valid approximation for fully connected layers.".format(
value))
self._default_fully_connected_approximation = value
@property
def default_conv2d_approximation(self):
return self._default_convolution_2d_approximation
@default_conv2d_approximation.setter
def default_conv2d_approximation(self, value):
if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
raise ValueError(
"{} is not a valid approximation for 2d convolutional layers.".format(
value))
self._default_convolution_2d_approximation = value
def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
"""Validates and registers the layer_key associated with the fisher_block.
@ -187,7 +278,8 @@ class LayerCollection(object):
# Find all keys that are either supersets or subsets of 'layer_key'.
inclusions = {
fisher_elt
for layer_elt in layer_key for fisher_elt in self.fisher_blocks
for layer_elt in layer_key
for fisher_elt in self.fisher_blocks
if self._equal_or_subset(layer_elt, fisher_elt)
}
@ -294,6 +386,49 @@ class LayerCollection(object):
def subgraph(self):
return self._subgraph
def define_linked_parameters(self, params, approximation=None):
"""Identify a set of parameters that should be grouped together.
During automatic graph scanning, any matches containing variables that have
been identified as part of a linked group will be filtered out unless
the match parameters are exactly equal to the ones specified in the linked
group.
Args:
params: A variable, or a tuple or list of variables. The variables
to be linked.
approximation: Optional string specifying the type of approximation to use
for these variables. If unspecified, this layer collection's default
approximation for the layer type will be used.
Raises:
ValueError: If the parameters were already registered in a layer or
identified as part of an incompatible group.
"""
params = frozenset(ensure_sequence(params))
# Check if any of the variables in 'params' is already in
# 'self.fisher_blocks.keys()'.
for registered_params, fisher_block in self.fisher_blocks.items():
registered_params_set = set(ensure_sequence(registered_params))
for variable in params:
if (variable in registered_params_set and
params != registered_params_set):
raise ValueError(
"Can't link parameters {}, variable {} was already registered in "
"group {} with layer {}".format(params, variable,
registered_params, fisher_block))
# Check if any of the variables in 'params' is already in
# 'self.linked_parameters'.
for variable in params:
for other_linked_params in self.linked_parameters:
if variable in other_linked_params:
raise ValueError("Can't link parameters {}, variable {} was already "
"linked in group {}.".format(params, variable,
other_linked_params))
self._linked_parameters[params] = approximation
def create_subgraph(self):
if not self.losses:
raise ValueError("Must have at least one registered loss.")
@ -307,11 +442,19 @@ class LayerCollection(object):
return math_ops.add_n(
tuple(loss.evaluate_on_sample() for loss in self.losses))
def _get_linked_approx(self, params):
"""If params were linked, return their specified approximation."""
params_set = frozenset(ensure_sequence(params))
if params_set in self.linked_parameters:
return self.linked_parameters[params_set]
else:
return None
def register_fully_connected(self,
params,
inputs,
outputs,
approx=APPROX_KRONECKER_NAME,
approx=None,
reuse=VARIABLE_SCOPE):
"""Registers a fully connnected layer.
@ -332,15 +475,15 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
approx_to_block_types = {
APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
}
if approx is None:
approx = self._get_linked_approx(params)
if approx is None:
approx = self.default_fully_connected_approximation
if approx not in approx_to_block_types:
if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
raise ValueError("Bad value {} for approx.".format(approx))
block_type = approx_to_block_types[approx]
block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
has_bias = isinstance(params, (tuple, list))
block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
@ -352,7 +495,7 @@ class LayerCollection(object):
padding,
inputs,
outputs,
approx=APPROX_KRONECKER_NAME,
approx=None,
reuse=VARIABLE_SCOPE):
"""Registers a convolutional layer.
@ -377,15 +520,16 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
approx_to_block_types = {
APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
}
if approx not in approx_to_block_types:
if approx is None:
approx = self._get_linked_approx(params)
if approx is None:
approx = self.default_conv2d_approximation
if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES:
raise ValueError("Bad value {} for approx.".format(approx))
block_type = approx_to_block_types[approx]
block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
block = self.register_block(
params, block_type(self, params, strides, padding), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
@ -393,7 +537,7 @@ class LayerCollection(object):
def register_generic(self,
params,
batch_size,
approx=APPROX_DIAGONAL_NAME,
approx=None,
reuse=VARIABLE_SCOPE):
"""Registers a generic layer.
@ -413,15 +557,16 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
approx_to_block_types = {
APPROX_FULL_NAME: fb.FullFB,
APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
}
if approx not in approx_to_block_types:
if approx is None:
approx = self._get_linked_approx(params)
if approx is None:
approx = self.default_generic_approximation
if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES:
raise ValueError("Bad value {} for approx.".format(approx))
block_type = approx_to_block_types[approx]
block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx]
block = self.register_block(params, block_type(self, params), reuse=reuse)
block.register_additional_minibatch(batch_size)
@ -560,10 +705,10 @@ class LayerCollection(object):
try:
hash(args)
except TypeError:
raise TypeError((
"Unable to use (cls, args) = ({}, {}) as a key in "
"LayerCollection.fisher_factors. The pair cannot be hashed."
).format(cls, args))
raise TypeError(
("Unable to use (cls, args) = ({}, {}) as a key in "
"LayerCollection.fisher_factors. The pair cannot be hashed.").format(
cls, args))
with variable_scope.variable_scope(self._var_scope):
return utils.setdefault(self.fisher_factors, (cls, args),


@ -44,7 +44,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
momentum=0.,
momentum_type="regular",
norm_constraint=None,
name="KFAC",):
name="KFAC",
estimation_mode="gradients"):
"""Initializes the KFAC optimizer with the given settings.
Args:
@ -72,6 +73,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
specified value. May only be used with momentum type 'regular'.
(Default: None)
name: The name for this optimizer. (Default: 'KFAC')
estimation_mode: The type of estimator to use for the Fishers. Can be
'gradients', 'empirical', 'curvature_propagation', or 'exact'.
(Default: 'gradients'). See the doc-string for FisherEstimator for a
more detailed description of these options.
Raises:
ValueError: If the momentum type is unsupported.
@ -86,7 +91,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
variables = tf_variables.trainable_variables()
self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping,
layer_collection)
layer_collection,
estimation_mode=estimation_mode)
momentum_type = momentum_type.lower()
legal_momentum_types = ["regular", "adam", "qmodel"]
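A hedged construction sketch for the new argument; the LayerCollection setup is elided and all values are placeholders:

from tensorflow.contrib.kfac.python.ops import optimizer as kfac_optimizer

opt = kfac_optimizer.KfacOptimizer(
    learning_rate=0.01,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=lc,  # a populated LayerCollection
    estimation_mode="empirical")  # or 'gradients' (default),
                                  # 'curvature_propagation', 'exact'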


@ -1403,7 +1403,8 @@ def dropout(inputs,
noise_shape=None,
is_training=True,
outputs_collections=None,
scope=None):
scope=None,
seed=None):
"""Returns a dropout op applied to the input.
With probability `keep_prob`, outputs the input element scaled up by
@ -1421,6 +1422,8 @@ def dropout(inputs,
Otherwise, inputs is returned.
outputs_collections: Collection to add the outputs.
scope: Optional scope for name_scope.
seed: A Python integer. Used to create random seeds. See
@{tf.set_random_seed} for behavior.
Returns:
A tensor representing the output of the operation.
@ -1430,6 +1433,7 @@ def dropout(inputs,
inputs = ops.convert_to_tensor(inputs)
layer = core_layers.Dropout(rate=1 - keep_prob,
noise_shape=noise_shape,
seed=seed,
name=sc.name,
_scope=sc)
outputs = layer.apply(inputs, training=is_training)
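A hedged sketch of the new pass-through argument: with a fixed op-level seed, the two dropout ops below draw identical masks, which is exactly what the new test below verifies:

import tensorflow as tf
from tensorflow.contrib.layers import dropout

images = tf.random_uniform((5, 10, 10, 3), seed=1)
out1 = dropout(images, seed=1)
out2 = dropout(images, seed=1)  # same mask as out1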


@ -1345,11 +1345,20 @@ class DropoutTest(test.TestCase):
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
output = _layers.dropout(images)
num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0))
sess.run(variables_lib.global_variables_initializer())
num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial])
self.assertLess(num_elem, num_elem_initial / 2 + 0.1)
self.assertGreater(num_elem, num_elem_initial / 2 - 0.1)
def testDropoutSeed(self):
"""Test that providing the same seed produces the same result."""
height, width = 10, 10
with self.test_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output1 = _layers.dropout(images, seed=1)
output2 = _layers.dropout(images, seed=1)
self.assertAllEqual(*sess.run([output1, output2]))
def testCreateDropoutNoTraining(self):
height, width = 3, 3
with self.test_session() as sess:
@ -1358,7 +1367,6 @@ class DropoutTest(test.TestCase):
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
output = _layers.dropout(images, is_training=False)
num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0))
sess.run(variables_lib.global_variables_initializer())
num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial])
self.assertEqual(num_elem, num_elem_initial)
outputs, inputs = sess.run([output, images])


@ -33,6 +33,7 @@ from __future__ import division
from __future__ import print_function
import os
import tempfile
import time
from tensorflow.contrib.layers.python.layers import feature_column
@ -644,18 +645,22 @@ def make_best_model_export_strategy(serving_input_fn,
# TODO(b/67013778): Revisit this approach when corresponding changes to
# TF Core are finalized.
def extend_export_strategy(base_export_strategy, post_export_fn,
post_export_name):
def extend_export_strategy(base_export_strategy,
post_export_fn,
post_export_name=None):
"""Extend ExportStrategy, calling post_export_fn after export.
Args:
base_export_strategy: An ExportStrategy that can be passed to the Experiment
constructor.
post_export_fn: A user-specified function to call after exporting the
SavedModel. Takes the export directory as an argument, and returns
a string path to a (potentially different) SavedModel.
SavedModel. Takes two arguments - the path to the SavedModel exported by
base_export_strategy and the directory where to export the SavedModel
modified by the post_export_fn. Returns the path to the exported
SavedModel.
post_export_name: The directory name under the export base directory where
SavedModels generated by the post_export_fn will be written.
SavedModels generated by the post_export_fn will be written. If None, the
directory name of base_export_strategy is used.
Returns:
An ExportStrategy that can be passed to the Experiment constructor.
@ -675,12 +680,24 @@ def extend_export_strategy(base_export_strategy, post_export_fn,
Raises:
ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
and `default_output_alternative_key` was specified.
and `default_output_alternative_key` was specified or if post_export_fn
does not return a valid directory.
"""
export_dir = base_export_strategy.export(estimator, export_dir_base,
checkpoint_path)
if post_export_fn:
export_dir = post_export_fn(export_dir)
return export_dir
tmp_base_export_dir = tempfile.mkdtemp()
tmp_base_export = base_export_strategy.export(
estimator, tmp_base_export_dir, checkpoint_path)
tmp_post_export_dir = tempfile.mkdtemp()
tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)
return export_strategy.ExportStrategy(post_export_name, export_fn)
if not tmp_post_export.startswith(tmp_post_export_dir):
raise ValueError('post_export_fn must return a sub-directory of {}'
.format(tmp_post_export_dir))
export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir)
gfile.Rename(
os.path.join(tmp_post_export_dir, export_relpath),
os.path.join(export_dir_base, export_relpath))
return os.path.join(export_dir_base, export_relpath)
name = post_export_name if post_export_name else base_export_strategy.name
return export_strategy.ExportStrategy(name, export_fn)
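Under the revised contract, post_export_fn receives both the base export path and a scratch directory, and must return a sub-directory of that scratch directory or the ValueError above fires. A hedged example:

import os
import shutil

def post_export_fn(base_export_path, post_export_dir):
  """Copies the SavedModel aside, rewrites it, and returns the new path."""
  dst = os.path.join(post_export_dir, 'rewritten')
  shutil.copytree(base_export_path, dst)
  # ... modify the SavedModel under dst here ...
  return dst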

Some files were not shown because too many files have changed in this diff Show More