commit
9089ab5982
@ -375,6 +375,7 @@ config_setting(
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//learning/meta_rank/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_fold/llgtm/...",
|
||||
],
|
||||
|
@ -39,21 +39,23 @@ static void AllocateFlags() {
|
||||
flags->tf_xla_min_cluster_size = 2;
|
||||
flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
|
||||
flags->tf_xla_clustering_debug = false;
|
||||
flag_list = new std::vector<Flag>({
|
||||
Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
|
||||
"Control compilation of operators into XLA computations on CPU and "
|
||||
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
|
||||
"things very likely to be improved; 2 = on for everything. "
|
||||
"Experimental."),
|
||||
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
|
||||
"Minimum number of operators in an XLA compilation. Ignored for "
|
||||
"operators placed on an XLA device or operators explicitly marked "
|
||||
"for compilation."),
|
||||
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
|
||||
"Maximum number of operators in an XLA compilation."),
|
||||
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
|
||||
"Dump graphs during XLA compilation."),
|
||||
});
|
||||
flags->tf_xla_cpu_global_jit = false;
|
||||
flag_list = new std::vector<Flag>(
|
||||
{Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
|
||||
"Control compilation of operators into XLA computations on CPU and "
|
||||
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
|
||||
"things very likely to be improved; 2 = on for everything. "
|
||||
"Experimental."),
|
||||
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
|
||||
"Minimum number of operators in an XLA compilation. Ignored for "
|
||||
"operators placed on an XLA device or operators explicitly marked "
|
||||
"for compilation."),
|
||||
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
|
||||
"Maximum number of operators in an XLA compilation."),
|
||||
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
|
||||
"Dump graphs during XLA compilation."),
|
||||
Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
|
||||
"Enables global JIT compilation for CPU via SessionOptions.")});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
|
@ -46,6 +46,8 @@ typedef struct {
|
||||
int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA
|
||||
// compilation.
|
||||
bool tf_xla_clustering_debug; // Dump graphs during XLA compilation.
|
||||
bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU
|
||||
// via SessionOptions.
|
||||
} MarkForCompilationPassFlags;
|
||||
|
||||
// Return a pointer to the MarkForCompilationPassFlags struct;
|
||||
|
@ -290,9 +290,11 @@ Status MarkForCompilationPass::Run(
|
||||
global_jit_level =
|
||||
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
|
||||
}
|
||||
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
||||
const FunctionLibraryDefinition* fld = options.flib_def;
|
||||
auto is_compilable = [global_jit_level, fld](const Node* node,
|
||||
const DeviceType& device_type) {
|
||||
|
||||
auto is_compilable = [global_jit_level, cpu_global_jit, fld](
|
||||
const Node* node, const DeviceType& device_type) {
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
|
||||
&registration)) {
|
||||
@ -315,7 +317,11 @@ Status MarkForCompilationPass::Run(
|
||||
if (status.ok()) return compile;
|
||||
|
||||
// Otherwise use the value of global_jit_level.
|
||||
return registration->enable_jit_by_default && global_jit_level > 0;
|
||||
// Ignore enable_jit_by_default if global jit compilation for CPU
|
||||
// is explicitly requested via tf_xla_cpu_global_jit flag
|
||||
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
|
||||
return (ignore_registration || registration->enable_jit_by_default) &&
|
||||
global_jit_level > 0;
|
||||
};
|
||||
return RunImpl(options, is_compilable);
|
||||
}
|
||||
@ -556,6 +562,7 @@ Status MarkForCompilationPass::RunImpl(
|
||||
if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
|
||||
registration->requires_compilation) {
|
||||
string& name = cluster_names[cluster];
|
||||
|
||||
if (name.empty()) {
|
||||
name = strings::StrCat("cluster_", cluster_sequence_num++);
|
||||
}
|
||||
|
@ -179,7 +179,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
|
@ -731,11 +731,12 @@ string DebugString(const Graph& graph,
|
||||
FunctionalizeCond::ClusterHandle::Vector* clusters) {
|
||||
string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n";
|
||||
std::map<FunctionalizeCond::ClusterHandle, string> subgraphs;
|
||||
auto name = [](const Node* n) {
|
||||
return strings::StrCat(n->type_string(), "_", n->id());
|
||||
};
|
||||
for (Node* n : graph.nodes()) {
|
||||
if (n->IsOp()) {
|
||||
strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(),
|
||||
" [label=\"", n->name(), "\"];\n");
|
||||
}
|
||||
strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"",
|
||||
name(n), "\"];\n");
|
||||
}
|
||||
for (auto kv : subgraphs) {
|
||||
strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n",
|
||||
@ -743,16 +744,11 @@ string DebugString(const Graph& graph,
|
||||
kv.first.ToString(), "\";\n", kv.second, "}\n");
|
||||
}
|
||||
for (Node* n : graph.nodes()) {
|
||||
if (!n->IsOp()) {
|
||||
continue;
|
||||
}
|
||||
for (Node* in : n->in_nodes()) {
|
||||
if (in->IsOp()) {
|
||||
strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n");
|
||||
}
|
||||
strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n");
|
||||
}
|
||||
}
|
||||
return strings::StrCat(ret, "}");
|
||||
return strings::StrCat(ret, "} // end");
|
||||
}
|
||||
|
||||
string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) {
|
||||
@ -761,16 +757,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) {
|
||||
return cluster.representative.ToString();
|
||||
};
|
||||
for (auto kv : clustered_graph) {
|
||||
strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second),
|
||||
" (", kv.second.switch_nodes.size(), ", ",
|
||||
kv.second.merge_nodes.size(), ")\"];\n");
|
||||
if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) {
|
||||
strings::StrAppend(
|
||||
&ret, kv.first.ToString(), " [label=\"", name(kv.second),
|
||||
kv.second.switch_nodes.empty()
|
||||
? ""
|
||||
: strings::StrCat(" switches=", kv.second.switch_nodes.size()),
|
||||
kv.second.merge_nodes.empty()
|
||||
? ""
|
||||
: strings::StrCat(" merges=", kv.second.merge_nodes.size()),
|
||||
"\"];\n");
|
||||
}
|
||||
}
|
||||
for (auto kv : clustered_graph) {
|
||||
for (auto in : kv.second.in_nodes) {
|
||||
strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n");
|
||||
}
|
||||
}
|
||||
return strings::StrCat(ret, "}");
|
||||
return strings::StrCat(ret, "} // end");
|
||||
}
|
||||
|
||||
bool IsDeadSwitch(const Node* node) {
|
||||
@ -790,9 +794,6 @@ bool IsDeadSwitch(const Node* node) {
|
||||
|
||||
void FunctionalizeCond::CreateClusters() {
|
||||
for (Node* node : graph_->nodes()) {
|
||||
if (!node->IsOp()) {
|
||||
continue;
|
||||
}
|
||||
if (IsSwitch(node)) {
|
||||
switch_nodes_.insert(node);
|
||||
} else if (IsMerge(node)) {
|
||||
@ -825,6 +826,10 @@ void FunctionalizeCond::CreateClusters() {
|
||||
clusters_.at(node).Merge(&clusters_.at(in));
|
||||
}
|
||||
}
|
||||
// Group all source clusters together.
|
||||
if (node->IsSource() || node->in_edges().empty()) {
|
||||
clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -876,7 +881,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
|
||||
for (const Node* in : node->in_nodes()) {
|
||||
ClusterHandle other_repr = Representative(in);
|
||||
// Skip source, sink and internal edges.
|
||||
if (!in->IsOp() || other_repr == repr) {
|
||||
if (other_repr == repr) {
|
||||
continue;
|
||||
}
|
||||
Cluster& cluster_node_in = clustered_graph_[other_repr];
|
||||
@ -887,7 +892,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
|
||||
for (const Node* out : node->out_nodes()) {
|
||||
ClusterHandle other_repr = Representative(out);
|
||||
// Skip source, sink and internal edges.
|
||||
if (!out->IsOp() || other_repr == repr) {
|
||||
if (other_repr == repr) {
|
||||
continue;
|
||||
}
|
||||
Cluster& cluster_node_out = clustered_graph_[other_repr];
|
||||
@ -897,6 +902,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
|
||||
}
|
||||
return cluster_node;
|
||||
};
|
||||
update_cluster_for_node(graph_->source_node());
|
||||
for (Node* node : switch_nodes_) {
|
||||
update_cluster_for_node(node).switch_nodes.insert(node);
|
||||
}
|
||||
@ -955,7 +961,7 @@ gtl::optional<FunctionalizeCond::Cluster*> FunctionalizeCond::GetSwitchCluster(
|
||||
for (Cluster* in : merge_cluster.in_nodes) {
|
||||
Cluster* cluster = in;
|
||||
if (in->switch_nodes.empty()) {
|
||||
if (in->in_nodes.size() != 1) {
|
||||
if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) {
|
||||
return gtl::nullopt;
|
||||
}
|
||||
// There is only a single `in` cluster.
|
||||
@ -1292,11 +1298,8 @@ std::vector<std::pair<int, FunctionalizeCond::Cluster*>>
|
||||
FunctionalizeCond::SortedMergeNodes() {
|
||||
VLOG(2) << "ProcessClusteredGraph";
|
||||
std::stack<std::pair<int, Cluster*>> stack;
|
||||
for (auto& c : clustered_graph_) {
|
||||
if (c.second.in_nodes.empty()) {
|
||||
stack.push({0, &c.second});
|
||||
}
|
||||
}
|
||||
// Initialize with the source node.
|
||||
stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]});
|
||||
|
||||
// Perform a depth-first traversal of the clustered graph computing the
|
||||
// switch-merge depth.
|
||||
|
@ -40,6 +40,11 @@ class ConstOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape(proto_.tensor_shape());
|
||||
|
||||
if (proto_.dtype() == DT_STRING) {
|
||||
LOG(WARNING) << "Not computing Const of type DT_STRING";
|
||||
ctx->SetInvalidOutput(0);
|
||||
return;
|
||||
}
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
// To avoid blowups for large constants filled with the same value,
|
||||
|
@ -345,6 +345,16 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
|
||||
expression->set_constant_value(constant);
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetInvalidOutput(int index) {
|
||||
const TensorShape shape;
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
xla::ComputationDataHandle handle;
|
||||
handle.set_handle(0);
|
||||
expression->set_handle(handle);
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
|
||||
Tensor* output = nullptr;
|
||||
// The shape of the output tensor is the shape of the resource itself
|
||||
|
@ -142,6 +142,10 @@ class XlaOpKernelContext {
|
||||
// SetConstantOutput where possible.
|
||||
void SetConstantOutput(int index, const Tensor& host_tensor);
|
||||
|
||||
// Sets output 'index' to an invalid value.
|
||||
// Any subsequent attempt to consume this output will cause an error.
|
||||
void SetInvalidOutput(int index);
|
||||
|
||||
// Status handling.
|
||||
void SetStatus(const Status& status) { context_->SetStatus(status); }
|
||||
Status status() { return context_->status(); }
|
||||
|
@ -265,6 +265,42 @@ bool BufferAssignment::SharesSliceAtIndex(
|
||||
GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
|
||||
const HloInstruction* hlo_b) const {
|
||||
using SliceSet =
|
||||
FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
|
||||
// Gets the slices all of instr's subshapes. If any subshape doesn't have an
|
||||
// assigned slice, returns the empty set.
|
||||
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
|
||||
SliceSet slices;
|
||||
Status status = ShapeUtil::ForEachSubshapeWithStatus(
|
||||
instr->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
auto shape_slices = GetAllSlices(instr, index);
|
||||
if (shape_slices.empty()) {
|
||||
return InvalidArgument("No slices assigned to part of instr.");
|
||||
}
|
||||
slices.insert(shape_slices.begin(), shape_slices.end());
|
||||
return Status::OK();
|
||||
});
|
||||
if (!status.ok()) {
|
||||
return {};
|
||||
}
|
||||
return slices;
|
||||
};
|
||||
|
||||
SliceSet slices_a = collect_slices(hlo_a);
|
||||
SliceSet slices_b = collect_slices(hlo_b);
|
||||
// hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
|
||||
// didn't return the empty set) for both HLOs, and the two resulting sets of
|
||||
// slices are disjoint.
|
||||
return !slices_a.empty() && !slices_b.empty() &&
|
||||
std::none_of(slices_a.begin(), slices_a.end(),
|
||||
[&](const BufferAllocation::Slice& slice) {
|
||||
return slices_b.count(slice) > 0;
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<BufferAllocation::Slice>
|
||||
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
|
||||
return GetUniqueTopLevelSlice(
|
||||
|
@ -327,6 +327,12 @@ class BufferAssignment {
|
||||
return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
|
||||
}
|
||||
|
||||
// Returns true if hlo_a and hlo_b both have at least one buffer assigned for
|
||||
// their top-level and each of their nested shape indices, and if hlo_a's
|
||||
// buffers are all different from hlo_b's buffers.
|
||||
bool HaveDisjointSlices(const HloInstruction* hlo_a,
|
||||
const HloInstruction* hlo_b) const;
|
||||
|
||||
// Returns the underlying points-to analysis used for this assignment.
|
||||
const TuplePointsToAnalysis& points_to_analysis() const {
|
||||
return liveness_->points_to_analysis();
|
||||
|
@ -27,14 +27,8 @@ namespace se = ::perftools::gputools;
|
||||
|
||||
namespace xla {
|
||||
|
||||
/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_;
|
||||
|
||||
// Ensures Compiler::platform_compiler_mutex_ points at a mutex. The mutex is
// created with `new` inside std::call_once, so the allocation happens at most
// once no matter how many times this function is called.
/* static */ void Compiler::LazyInitMutex() {
  static std::once_flag init_once;
  const auto allocate_mutex = []() {
    Compiler::platform_compiler_mutex_ = new tensorflow::mutex;
  };
  std::call_once(init_once, allocate_mutex);
}
|
||||
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
/* static */ std::map<perftools::gputools::Platform::Id,
|
||||
Compiler::CompilerFactory>*
|
||||
@ -55,8 +49,7 @@ Compiler::GetPlatformCompilers() {
|
||||
/* static */ void Compiler::RegisterCompilerFactory(
|
||||
se::Platform::Id platform_id,
|
||||
std::function<std::unique_ptr<Compiler>()> compiler_factory) {
|
||||
LazyInitMutex();
|
||||
tensorflow::mutex_lock lock(*platform_compiler_mutex_);
|
||||
tensorflow::mutex_lock lock(platform_compiler_mutex_);
|
||||
auto* factories = GetPlatformCompilerFactories();
|
||||
CHECK(factories->find(platform_id) == factories->end())
|
||||
<< "Compiler factory already registered for platform";
|
||||
@ -65,8 +58,7 @@ Compiler::GetPlatformCompilers() {
|
||||
|
||||
/* static */ StatusOr<Compiler*> Compiler::GetForPlatform(
|
||||
const se::Platform* platform) {
|
||||
LazyInitMutex();
|
||||
tensorflow::mutex_lock lock(*platform_compiler_mutex_);
|
||||
tensorflow::mutex_lock lock(platform_compiler_mutex_);
|
||||
|
||||
auto* compilers = GetPlatformCompilers();
|
||||
// See if we already instantiated a compiler for this platform.
|
||||
|
@ -157,8 +157,7 @@ class Compiler {
|
||||
|
||||
private:
|
||||
// Mutex that guards the platform-compiler map.
|
||||
static tensorflow::mutex* platform_compiler_mutex_;
|
||||
static void LazyInitMutex();
|
||||
static tensorflow::mutex platform_compiler_mutex_;
|
||||
|
||||
// Map from platform kind to compiler factory.
|
||||
static std::map<perftools::gputools::Platform::Id, CompilerFactory>*
|
||||
|
@ -94,7 +94,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
|
||||
se::Platform::Id platform_id,
|
||||
ComputationPlacerCreationFunction creation_function) {
|
||||
tensorflow::mutex_lock lock(
|
||||
*ComputationPlacer::platform_computation_placer_mutex());
|
||||
ComputationPlacer::platform_computation_placer_mutex_);
|
||||
auto* computation_placers = GetPlatformComputationPlacers();
|
||||
CHECK(computation_placers->find(platform_id) == computation_placers->end());
|
||||
(*computation_placers)[platform_id].creation_function = creation_function;
|
||||
@ -103,7 +103,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
|
||||
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
|
||||
const se::Platform* platform) {
|
||||
tensorflow::mutex_lock lock(
|
||||
*ComputationPlacer::platform_computation_placer_mutex());
|
||||
ComputationPlacer::platform_computation_placer_mutex_);
|
||||
auto* computation_placers = GetPlatformComputationPlacers();
|
||||
|
||||
auto it = computation_placers->find(platform->id());
|
||||
@ -122,11 +122,9 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
|
||||
return it->second.placer.get();
|
||||
}
|
||||
|
||||
// Returns the mutex held in a function-local static pointer. The mutex is
// allocated with `new` when that static is initialized, and this function
// never deletes it.
/* static */ tensorflow::mutex*
ComputationPlacer::platform_computation_placer_mutex() {
  static tensorflow::mutex* const placer_mutex = new tensorflow::mutex;
  return placer_mutex;
}
|
||||
/* static */ tensorflow::mutex
|
||||
ComputationPlacer::platform_computation_placer_mutex_(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
/* static */ std::map<perftools::gputools::Platform::Id,
|
||||
ComputationPlacer::State>*
|
||||
|
@ -89,11 +89,8 @@ class ComputationPlacer {
|
||||
const perftools::gputools::Platform* platform);
|
||||
|
||||
private:
|
||||
// Routine that returns the mutex that guards the platform-to-computation
|
||||
// placer map. Done as a routine to ensure correct initialization ordering,
|
||||
// since RegisterComputationPlacer can be called during program initialization
|
||||
// time.
|
||||
static tensorflow::mutex* platform_computation_placer_mutex();
|
||||
// The mutex that guards the platform-to-computation placer map.
|
||||
static tensorflow::mutex platform_computation_placer_mutex_;
|
||||
|
||||
// State kept for each kind of ComputationPlacer. Registration functions set
|
||||
// up creation_function, and then we use that to lazily create "placer" the
|
||||
|
@ -17,6 +17,7 @@ package_group(
|
||||
load(":build_defs.bzl", "runtime_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
|
||||
|
||||
# Filegroup used to collect source files for dependency checking.
|
||||
filegroup(
|
||||
@ -83,6 +84,7 @@ cc_library(
|
||||
":cpu_options",
|
||||
":cpu_parallelization_preparation",
|
||||
":disassembler",
|
||||
":dot_op_emitter",
|
||||
":ir_emission_utils",
|
||||
":ir_emitter",
|
||||
":layout_assignment",
|
||||
@ -156,21 +158,23 @@ cc_library(
|
||||
":custom_call_target_registry",
|
||||
":disassembler",
|
||||
":external_constant_pool",
|
||||
":orc_jit_memory_mapper",
|
||||
":runtime_conv2d",
|
||||
":runtime_fork_join",
|
||||
":runtime_matmul",
|
||||
":runtime_single_threaded_conv2d",
|
||||
":runtime_single_threaded_matmul",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm//:core",
|
||||
"@llvm//:execution_engine",
|
||||
"@llvm//:core",
|
||||
"@llvm//:mc", # fixdeps: keep
|
||||
"@llvm//:orc_jit",
|
||||
"@llvm//:support",
|
||||
"@llvm//:target", # fixdeps: keep
|
||||
],
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
] + ORC_JIT_MEMORY_MAPPER_TARGETS,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -282,7 +286,6 @@ cc_library(
|
||||
deps = [
|
||||
":cpu_options",
|
||||
":cpu_runtime",
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -619,6 +622,7 @@ cc_library(
|
||||
srcs = ["layout_assignment.cc"],
|
||||
hdrs = ["layout_assignment.h"],
|
||||
deps = [
|
||||
":dot_op_emitter",
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:computation_layout",
|
||||
@ -706,6 +710,7 @@ cc_library(
|
||||
srcs = ["parallel_task_assignment.cc"],
|
||||
hdrs = ["parallel_task_assignment.h"],
|
||||
deps = [
|
||||
":dot_op_emitter",
|
||||
":ir_emission_utils",
|
||||
":shape_partition",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -735,6 +740,16 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Library built from orc_jit_memory_mapper.{cc,h}; depends on TensorFlow core
# lib and the LLVM execution engine.
cc_library(
    name = "orc_jit_memory_mapper",
    srcs = ["orc_jit_memory_mapper.cc"],
    hdrs = ["orc_jit_memory_mapper.h"],
    deps = [
        "//tensorflow/core:lib",
        "@llvm//:execution_engine",
    ],
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
|
@ -54,6 +54,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h"
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
@ -950,5 +949,119 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
|
||||
return index;
|
||||
}
|
||||
|
||||
// Return whether the given shape is a matrix with no padding.
|
||||
static bool IsRank2WithNoPadding(const Shape& shape) {
|
||||
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
|
||||
}
|
||||
|
||||
// In a gemm operation where output = lhs * rhs, check whether the given shapes
|
||||
// are valid for the operation.
|
||||
static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
const Shape& output_shape) {
|
||||
// The inputs and the output must
|
||||
// 1) be matrices with no padding, and
|
||||
// 2) have an allowed element type.
|
||||
return output_shape.element_type() == F32 &&
|
||||
IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
|
||||
IsRank2WithNoPadding(output_shape);
|
||||
}
|
||||
|
||||
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
|
||||
// For certain types of Dot, we can call Eigen
|
||||
if (hlo.opcode() == HloOpcode::kDot) {
|
||||
const Shape& lhs_shape = hlo.operand(0)->shape();
|
||||
const Shape& rhs_shape = hlo.operand(1)->shape();
|
||||
|
||||
if (ShapeUtil::HasZeroElements(lhs_shape) ||
|
||||
ShapeUtil::HasZeroElements(rhs_shape)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
|
||||
DotInLlvmIrProfitable::kYes ||
|
||||
ProfitableToImplementDotInTiledLlvmIr(hlo)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If gemm can accept the operand shapes, use it rather than a custom
|
||||
// kernel.
|
||||
if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
|
||||
// The size of the reduction dimension should match. The shape inference
|
||||
// guarantees this invariant, so the check here is for programming
|
||||
// errors.
|
||||
CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hlo.opcode() == HloOpcode::kFusion &&
|
||||
hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
|
||||
hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
|
||||
auto* dot = hlo.fused_expression_root();
|
||||
const Shape& lhs_shape = dot->operand(0)->shape();
|
||||
const Shape& rhs_shape = dot->operand(1)->shape();
|
||||
if (ShapeUtil::HasZeroElements(lhs_shape) ||
|
||||
ShapeUtil::HasZeroElements(rhs_shape)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Heuristic for choosing the untiled LLVM IR lowering of `dot`.
//
// Returns kNo unless `dot` is a kDot with a rank-2 result whose first result
// dimension (M) is 1, whose reduction dimension fits the byte threshold below,
// and which is either running with single-threaded Eigen or is small enough
// that sharding across Eigen worker threads is not expected to pay off.
// When those hold, the answer is kWithColumnMajorRhs if operand 1's layout is
// monotonic with dim 0 major, and kYes otherwise.
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
    const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot || dot.shape().dimensions_size() != 2) {
    return DotInLlvmIrProfitable::kNo;
  }

  const Shape& result_shape = dot.shape();

  // 8 KiB, i.e. a quarter of a 32 KiB L1 cache (NOTE(review): assumed typical
  // L1 size), so the reduction dimension of both the LHS and the RHS fit with
  // some space "left over". This needs to be tuned further.
  const int64 kReductionDimensionThresholdBytes = 8 * 1024;
  const bool single_threaded_eigen =
      !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();

  // Above this size it is better to call into Eigen and shard the dot across
  // multiple worker threads. Rough estimate from a local matmul benchmark; it
  // can be tuned further.
  const int64 kMaxSingleThreadedFlops = 16 * 1024;

  const int64 M = result_shape.dimensions(0);
  const int64 N = result_shape.dimensions(1);
  const int64 K = dot.operand(1)->shape().dimensions(0);
  const int64 primitive_type_size =
      ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type());

  // The checks are evaluated in this order, one at a time, on purpose: each
  // product is only formed once the previous condition has passed.
  if (M != 1) {
    return DotInLlvmIrProfitable::kNo;
  }
  // Look for a configuration where we will likely be able to keep LHS in L1
  // and do a cache-optimal traversal of RHS.
  if (K * primitive_type_size > kReductionDimensionThresholdBytes) {
    return DotInLlvmIrProfitable::kNo;
  }
  // Bail out on matrices large enough that Eigen can profitably shard the
  // computation across cores. Only applies when multi-threading is enabled.
  if (!single_threaded_eigen && M * K * N > kMaxSingleThreadedFlops) {
    return DotInLlvmIrProfitable::kNo;
  }

  if (LayoutUtil::IsMonotonicWithDim0Major(dot.operand(1)->shape().layout())) {
    return DotInLlvmIrProfitable::kWithColumnMajorRhs;
  }
  return DotInLlvmIrProfitable::kYes;
}
|
||||
|
||||
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
|
||||
// Any Matrix-Vector product of floating point or integral type, or
|
||||
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
|
||||
// IR implementation.
|
||||
const Shape& shape = dot.shape();
|
||||
return shape.dimensions_size() == 2 &&
|
||||
(shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
|
||||
(primitive_util::IsFloatingPointType(shape.element_type()) ||
|
||||
primitive_util::IsIntegralType(shape.element_type()));
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -30,6 +30,26 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
|
||||
// Returns true if |hlo| is a dot (or a transpose-dot fusion rooted at a dot)
// that may be implemented by calling into Eigen.
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo);
|
||||
|
||||
enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
|
||||
|
||||
// Returns a value to indicate if (and under what conditions) will lowering
|
||||
// |dot| as an untiled LLVM IR dot operation be profitable over calling into
|
||||
// Eigen or emitting a tiled LLVM IR implementation. Possible return values
|
||||
// are:
|
||||
//
|
||||
// * DotInLlvmIrProfitable::kYes - always profitable.
|
||||
// * DotInLlvmIrProfitable::kNo - never profitable.
|
||||
// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
|
||||
// the Rhs layout column major.
|
||||
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
|
||||
const HloInstruction& dot);
|
||||
|
||||
// Returns true to indicate that we can generate a tiled LLVM IR implementation
|
||||
// for |dot|.
|
||||
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
|
||||
|
||||
// Helper class for emitting LLVM IR to perform the dot operation.
|
||||
class DotOpEmitter {
|
||||
public:
|
||||
|
@ -74,122 +74,5 @@ bool PotentiallyImplementedAsEigenConvolution(
|
||||
kernel_shape.dimensions_size() - 1;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Return whether the given shape is a matrix with no padding.
|
||||
bool IsRank2WithNoPadding(const Shape& shape) {
|
||||
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
|
||||
}
|
||||
|
||||
// In a gemm operation where output = lhs * rhs, check whether the given shapes
|
||||
// are valid for the operation.
|
||||
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
const Shape& output_shape) {
|
||||
// The inputs and the output must
|
||||
// 1) be matrices with no padding, and
|
||||
// 2) have an allowed element type.
|
||||
return output_shape.element_type() == F32 &&
|
||||
IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
|
||||
IsRank2WithNoPadding(output_shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
|
||||
// For certain types of Dot, we can call Eigen
|
||||
if (hlo.opcode() == HloOpcode::kDot) {
|
||||
const Shape& lhs_shape = hlo.operand(0)->shape();
|
||||
const Shape& rhs_shape = hlo.operand(1)->shape();
|
||||
|
||||
if (ShapeUtil::HasZeroElements(lhs_shape) ||
|
||||
ShapeUtil::HasZeroElements(rhs_shape)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
|
||||
DotInLlvmIrProfitable::kYes ||
|
||||
ProfitableToImplementDotInTiledLlvmIr(hlo)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If gemm can accept the operand shapes, use it rather than a custom
|
||||
// kernel.
|
||||
if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
|
||||
// The size of the reduction dimension should match. The shape inference
|
||||
// guarantees this invariant, so the check here is for programming
|
||||
// errors.
|
||||
CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hlo.opcode() == HloOpcode::kFusion &&
|
||||
hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
|
||||
hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
|
||||
auto* dot = hlo.fused_expression_root();
|
||||
const Shape& lhs_shape = dot->operand(0)->shape();
|
||||
const Shape& rhs_shape = dot->operand(1)->shape();
|
||||
if (ShapeUtil::HasZeroElements(lhs_shape) ||
|
||||
ShapeUtil::HasZeroElements(rhs_shape)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Decides whether `dot` should be emitted as an untiled LLVM IR dot instead of
// being handed to Eigen or to the tiled LLVM IR emitter.
//
// Returns kNo unless `dot` is a rank-2 kDot whose result has a single row, a
// reduction dimension small enough to fit the byte threshold below, and (when
// Eigen is multi-threaded) a total amount of work small enough that sharding
// across threads would not pay off. When those hold, the result is
// kWithColumnMajorRhs if the RHS layout is currently dim-0-major, and kYes
// otherwise.
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
    const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot || dot.shape().dimensions_size() != 2) {
    return DotInLlvmIrProfitable::kNo;
  }

  // Upper bound, in bytes, on the reduction dimension. NOTE(review): 8 KiB
  // looks intended as roughly a quarter of a typical L1 cache, so that the
  // reduction dimension of both LHS and RHS fits with room to spare; this
  // still needs tuning.
  const int64 kReductionDimensionThresholdBytes = 8 * 1024;

  // Above this much work it is better to call into Eigen and let it shard the
  // dot across worker threads. This is a rough estimate taken from a matmul
  // benchmark on a single machine and can be tuned further.
  const int64 kMaxSingleThreadedFlops = 16 * 1024;

  const Shape& output_shape = dot.shape();
  const bool eigen_is_multi_threaded =
      dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();

  const int64 rows = output_shape.dimensions(0);
  const int64 cols = output_shape.dimensions(1);
  const int64 reduction = dot.operand(1)->shape().dimensions(0);
  const int64 element_bytes =
      ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());

  // Heuristic 1: only a single-row result with a short reduction dimension
  // lets us keep LHS in L1 and traverse RHS in a cache-friendly order.
  if (rows != 1) {
    return DotInLlvmIrProfitable::kNo;
  }
  if (reduction * element_bytes > kReductionDimensionThresholdBytes) {
    return DotInLlvmIrProfitable::kNo;
  }

  // Heuristic 2: bail out on problems big enough for Eigen to shard
  // profitably across cores. Only relevant when multi-threading is enabled.
  if (eigen_is_multi_threaded &&
      rows * reduction * cols > kMaxSingleThreadedFlops) {
    return DotInLlvmIrProfitable::kNo;
  }

  if (LayoutUtil::IsMonotonicWithDim0Major(dot.operand(1)->shape().layout())) {
    return DotInLlvmIrProfitable::kWithColumnMajorRhs;
  }
  return DotInLlvmIrProfitable::kYes;
}
|
||||
|
||||
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
|
||||
// Any Matrix-Vector product of floating point or integral type, or
|
||||
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
|
||||
// IR implementation.
|
||||
const Shape& shape = dot.shape();
|
||||
return shape.dimensions_size() == 2 &&
|
||||
(shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
|
||||
(primitive_util::IsFloatingPointType(shape.element_type()) ||
|
||||
primitive_util::IsIntegralType(shape.element_type()));
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -23,27 +23,6 @@ namespace cpu {
|
||||
|
||||
bool PotentiallyImplementedAsEigenConvolution(
|
||||
const HloInstruction& convolution);
|
||||
|
||||
bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot);
|
||||
|
||||
enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
|
||||
|
||||
// Returns a value to indicate if (and under what conditions) will lowering
|
||||
// |dot| as an untiled LLVM IR dot operation be profitable over calling into
|
||||
// Eigen or emitting a tiled LLVM IR implementation. Possible return values
|
||||
// are:
|
||||
//
|
||||
// * DotInLlvmIrProfitable::kYes - always profitable.
|
||||
// * DotInLlvmIrProfitable::kNo - never profitable.
|
||||
// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
|
||||
// the Rhs layout column major.
|
||||
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
|
||||
const HloInstruction& dot);
|
||||
|
||||
// Returns true to indicate that we can generate a tiled LLVM IR implementation
|
||||
// for |dot|.
|
||||
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <numeric>
|
||||
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
|
40
tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc
Normal file
40
tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace orc_jit_memory_mapper {
|
||||
|
||||
static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
static llvm::SectionMemoryManager::MemoryMapper* mapper_instance
|
||||
GUARDED_BY(mapper_instance_mutex) = nullptr;
|
||||
|
||||
// Returns the currently registered memory mapper, or nullptr when none has
// been registered. The read is performed under mapper_instance_mutex.
llvm::SectionMemoryManager::MemoryMapper* GetInstance() {
  llvm::SectionMemoryManager::MemoryMapper* registered = nullptr;
  {
    tensorflow::mutex_lock lock(mapper_instance_mutex);
    registered = mapper_instance;
  }
  return registered;
}
|
||||
|
||||
// Registers `mapper` as the process-wide ORC JIT memory mapper.
//
// Ownership of the mapper is released into the file-level `mapper_instance`
// pointer, which is what GetInstance() hands out afterwards.
//
// A null `mapper` is a no-op, as the Registrar contract in
// orc_jit_memory_mapper.h states: it must not disturb a mapper that is
// already registered. Storing the released null pointer unconditionally would
// reset the registration to nullptr and drop the only pointer to the
// previously registered mapper.
Registrar::Registrar(
    std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper> mapper) {
  if (mapper == nullptr) {
    return;
  }

  tensorflow::mutex_lock lock(mapper_instance_mutex);
  mapper_instance = mapper.release();
}
|
||||
} // namespace orc_jit_memory_mapper
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
56
tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h
Normal file
56
tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h
Normal file
@ -0,0 +1,56 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
|
||||
namespace orc_jit_memory_mapper {
|
||||
// Returns the registered memory mapper if there is one. Returns nullptr if no
|
||||
// memory mapper is registered.
|
||||
llvm::SectionMemoryManager::MemoryMapper* GetInstance();
|
||||
|
||||
class Registrar {
|
||||
public:
|
||||
// Registers the `mapper` as a memory mapper. This is a no-op if `mapper` is
|
||||
// null. Precondition: no other memory mapper has been registered yet.
|
||||
explicit Registrar(
|
||||
std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper> mapper);
|
||||
};
|
||||
} // namespace orc_jit_memory_mapper
|
||||
|
||||
#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(mapper_instance, ctr) \
|
||||
static ::xla::cpu::orc_jit_memory_mapper::Registrar \
|
||||
XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr)(mapper_instance)
|
||||
|
||||
// __COUNTER__ must go through another macro to be properly expanded
|
||||
#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr) \
|
||||
__orc_jit_memory_mapper_registrar_##ctr
|
||||
|
||||
// Registers the std::unique_ptr<llvm::SectionMemoryManager::MemoryMapper>
|
||||
// returned by the `factory` expression. `factory` is allowed to evaluate to
|
||||
// a null unique_ptr in which case this macro does nothing.
|
||||
#define XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(factory) \
|
||||
XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(factory, __COUNTER__)
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
|
||||
@ -125,8 +126,10 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
/*MAttrs=*/DetectMachineAttributes()))),
|
||||
disassembler_(*target_machine_),
|
||||
data_layout_(target_machine_->createDataLayout()),
|
||||
object_layer_(
|
||||
[] { return std::make_shared<llvm::SectionMemoryManager>(); }),
|
||||
object_layer_([] {
|
||||
return std::make_shared<llvm::SectionMemoryManager>(
|
||||
orc_jit_memory_mapper::GetInstance());
|
||||
}),
|
||||
compile_layer_(
|
||||
object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
@ -258,7 +259,9 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
|
||||
return InternalError("couldn't get temp CUBIN file name");
|
||||
}
|
||||
auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] {
|
||||
TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(cubin_path));
|
||||
// CUBIN file may never be created, so the failure to delete it should not
|
||||
// produce TF error.
|
||||
tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError();
|
||||
});
|
||||
tensorflow::SubProcess ptxas_info_dumper;
|
||||
std::vector<string> ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path,
|
||||
@ -500,10 +503,24 @@ std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
|
||||
VLOG(2) << "Compiled PTX size:" << ptx.size()
|
||||
<< " CUBIN size: " << cache_value->cubin_data.size();
|
||||
} else {
|
||||
LOG(WARNING)
|
||||
<< "Failed to compile ptx to cubin. Will attempt to let "
|
||||
"GPU driver compile the ptx. "
|
||||
<< maybe_cubin.status();
|
||||
bool log_warning = true;
|
||||
if (maybe_cubin.status().code() ==
|
||||
tensorflow::error::Code::NOT_FOUND) {
|
||||
// Missing ptxas is expected in some environments where CUDA SDK
|
||||
// binaries are not available. We don't want to spam logs with
|
||||
// identical warnings in this case.
|
||||
|
||||
// TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N
|
||||
// for more general usage.
|
||||
static std::atomic<bool> warning_done(false);
|
||||
log_warning = !warning_done.exchange(true);
|
||||
}
|
||||
if (log_warning) {
|
||||
LOG(WARNING)
|
||||
<< "Failed to compile ptx to cubin. Will attempt to let "
|
||||
"GPU driver compile the ptx. "
|
||||
<< maybe_cubin.status();
|
||||
}
|
||||
}
|
||||
}
|
||||
cache_value->compilation_done = true;
|
||||
|
@ -166,11 +166,46 @@ void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
|
||||
*(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
|
||||
}
|
||||
|
||||
// Determines whether hlo's buffers are never modified within the execution of
|
||||
// consumer.
|
||||
static bool BuffersInvariantWithinConsumer(
|
||||
const HloInstruction& hlo, const HloInstruction& consumer,
|
||||
const BufferAssignment* buffer_assignment) {
|
||||
// Check if consumer is inside a fusion node -- if so, "dereference" it until
|
||||
// we get to a non-fusion node.
|
||||
const HloInstruction* c = &consumer;
|
||||
while (c->IsFused()) {
|
||||
c = c->parent()->FusionInstruction();
|
||||
}
|
||||
|
||||
// If, after dereferencing c, we end up with a node that's not inside our
|
||||
// module's top-level computation (say our node is inside a while loop), we
|
||||
// give up on marking array as invariant, because this HLO may be run multiple
|
||||
// times (e.g. multiple while loop iterations, or multiple invocations of a
|
||||
// reducer's computation). TODO(jlebar): We could relax this constraint if we
|
||||
// emitted an llvm.invariant.group.barrier at the end of the computation.
|
||||
return c->parent() == c->GetModule()->entry_computation() &&
|
||||
buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
|
||||
}
|
||||
|
||||
// Builds the IrArray for the output of `hlo` at `shape_index`, as seen from
// the IR that implements `consumer`. Aliasing metadata is always attached;
// invariance metadata is attached only when `hlo`'s buffers are known not to
// change during `consumer`.
llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  llvm::Value* base_pointer = GetBasePointer(hlo, shape_index);
  llvm_ir::IrArray result(base_pointer,
                          ShapeUtil::GetSubshape(hlo.shape(), shape_index));
  alias_analysis_.AddAliasingInformationToIrArray(hlo, &result);

  // The GPU backend emits one kernel per top-level HLO, and LLVM treats the
  // execution of a single kernel as the "whole program" run on the GPU. So if
  // consumer does not modify hlo's output buffer, and consumer runs hlo only
  // once (so no second, different output is produced), the array can be
  // marked invariant over the whole program.
  const bool invariant_for_consumer =
      BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_);
  if (invariant_for_consumer) {
    VLOG(2) << "Marking " << hlo.name() << " as invariant within "
            << consumer.name();
    result.MarkInvariantOverWholeProgram(&module_->getContext());
  }

  return result;
}
|
||||
|
||||
|
@ -76,8 +76,15 @@ class HloToIrBindings {
|
||||
return it->second.element(shape_index);
|
||||
}
|
||||
|
||||
// Return the underlying IrArray of the output of the given instruction.
|
||||
// Returns the IrArray which contains the output of hlo.
|
||||
//
|
||||
// consumer is the HLO in which this IrArray is used -- we use this to (try
|
||||
// to) add metadata indicating that the array is invariant within consumer.
|
||||
//
|
||||
// To get the buffer into which hlo should write its own output, call
|
||||
// GetIrArray(hlo, hlo).
|
||||
llvm_ir::IrArray GetIrArray(const HloInstruction& hlo,
|
||||
const HloInstruction& consumer,
|
||||
const ShapeIndex& shape_index = {});
|
||||
|
||||
private:
|
||||
|
@ -68,7 +68,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
|
||||
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
|
||||
for (const HloInstruction* operand : hlo->operands()) {
|
||||
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_);
|
||||
return GetIrArray(*operand, *hlo)
|
||||
.EmitReadArrayElement(index, &ir_builder_);
|
||||
};
|
||||
}
|
||||
return EmitTargetElementLoop(
|
||||
@ -145,7 +146,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
||||
for (const HloInstruction* operand : tuple->operands()) {
|
||||
base_ptrs.push_back(GetBasePointer(*operand));
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_,
|
||||
module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -334,7 +336,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
|
||||
TF_RET_CHECK(pred->shape().element_type() == PRED);
|
||||
|
||||
if (ShapeUtil::IsTuple(select->shape())) {
|
||||
llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
|
||||
llvm_ir::EmitTupleSelect(GetIrArray(*select, *select),
|
||||
GetIrArray(*pred, *select),
|
||||
GetBasePointer(*on_true),
|
||||
GetBasePointer(*on_false), &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
@ -349,9 +352,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
|
||||
Status IrEmitter::HandleDot(HloInstruction* dot) {
|
||||
auto lhs_instruction = dot->operand(0);
|
||||
auto rhs_instruction = dot->operand(1);
|
||||
const llvm_ir::IrArray& target_array = GetIrArray(*dot);
|
||||
const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction);
|
||||
const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction);
|
||||
const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
|
||||
const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
|
||||
const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);
|
||||
|
||||
const Shape& lhs_shape = lhs_instruction->shape();
|
||||
const Shape& rhs_shape = rhs_instruction->shape();
|
||||
@ -571,7 +574,8 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
|
||||
|
||||
// Apply the reduction function to the loaded value.
|
||||
llvm::Value* input_address =
|
||||
GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_);
|
||||
GetIrArray(*arg, *reduce)
|
||||
.EmitArrayElementAddress(input_index, &ir_builder_);
|
||||
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
|
||||
*function, {accumulator_addr, input_address}, accumulator_addr));
|
||||
|
||||
@ -587,7 +591,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
|
||||
std::vector<llvm_ir::IrArray> parameter_arrays;
|
||||
for (HloInstruction* operand : fusion->operands()) {
|
||||
parameter_arrays.push_back(GetIrArray(*operand));
|
||||
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
|
||||
}
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_,
|
||||
&ir_builder_, GetNestedComputer());
|
||||
@ -622,7 +626,8 @@ Status IrEmitter::HandleRng(HloInstruction* random) {
|
||||
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
|
||||
for (const HloInstruction* operand : random->operands()) {
|
||||
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_);
|
||||
return GetIrArray(*operand, *random)
|
||||
.EmitReadArrayElement(index, &ir_builder_);
|
||||
};
|
||||
}
|
||||
// Emits a single-threaded loop because the loop body generated by the element
|
||||
@ -631,7 +636,7 @@ Status IrEmitter::HandleRng(HloInstruction* random) {
|
||||
GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
|
||||
GetNestedComputer())
|
||||
.MakeElementGenerator(random, operand_to_generator),
|
||||
GetIrArray(*random), &ir_builder_)
|
||||
GetIrArray(*random, *random), &ir_builder_)
|
||||
.EmitLoop(IrName(random));
|
||||
}
|
||||
|
||||
|
@ -105,10 +105,16 @@ class IrEmitter : public DfsHloVisitorWithDefault {
|
||||
explicit IrEmitter(const HloModuleConfig& hlo_module_config,
|
||||
IrEmitterContext* ir_emitter_context, bool is_nested);
|
||||
|
||||
// A convenient helper for calling HloToIrBindings::GetIrArray.
|
||||
// Helper for calling HloToIrBindings::GetIrArray.
|
||||
//
|
||||
// Gets the IrArray which contains inst. This array has metadata that makes
|
||||
// it valid only within the IR that implements consumer. If you are
|
||||
// implementing an HLO and want to get its own output buffer, call
|
||||
// GetIrArray(hlo, hlo).
|
||||
llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
|
||||
const HloInstruction& consumer,
|
||||
const ShapeIndex& shape_index = {}) {
|
||||
return bindings_.GetIrArray(inst, shape_index);
|
||||
return bindings_.GetIrArray(inst, consumer, shape_index);
|
||||
}
|
||||
// A convenient helper for calling HloToIrBindings::GetBasePointer.
|
||||
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
|
||||
|
@ -115,7 +115,8 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
|
||||
Status IrEmitterNested::EmitTargetElementLoop(
|
||||
const HloInstruction& hlo,
|
||||
const llvm_ir::ElementGenerator& element_generator) {
|
||||
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_)
|
||||
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo),
|
||||
&ir_builder_)
|
||||
.EmitLoop();
|
||||
}
|
||||
|
||||
|
@ -282,7 +282,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
|
||||
std::vector<llvm_ir::IrArray> parameter_arrays;
|
||||
for (HloInstruction* operand : fusion->operands()) {
|
||||
parameter_arrays.push_back(GetIrArray(*operand));
|
||||
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
|
||||
}
|
||||
GpuElementalIrEmitter elemental_emitter(
|
||||
hlo_module_config_, ir_emitter_context_->llvm_module(),
|
||||
@ -344,7 +344,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
|
||||
std::vector<llvm_ir::IrArray> operand_arrays;
|
||||
for (HloInstruction* operand : fusion->operands()) {
|
||||
operand_arrays.push_back(GetIrArray(*operand));
|
||||
operand_arrays.push_back(GetIrArray(*operand, *fusion));
|
||||
}
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
|
||||
ir_emitter_context_->llvm_module(),
|
||||
@ -355,7 +355,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
|
||||
// Array to write into. Because this is an in-place operation, this is the
|
||||
// same as operand 0's array.
|
||||
llvm_ir::IrArray output_array = GetIrArray(*fusion);
|
||||
llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
|
||||
|
||||
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
|
||||
update_shape, ir_emitter_context_->device_description());
|
||||
@ -693,9 +693,10 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
|
||||
constexpr int64 tile_size = 32;
|
||||
constexpr int64 num_rows = 8;
|
||||
int64 num_tiles = EmitTranspose021Tiled(
|
||||
GetIrArray(*(copy->operand(0)))
|
||||
GetIrArray(*copy->operand(0), *copy)
|
||||
.CastToShape(reduced_input_shape, &ir_builder_),
|
||||
GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_),
|
||||
GetIrArray(*copy, *copy)
|
||||
.CastToShape(reduced_output_shape, &ir_builder_),
|
||||
tile_size, num_rows, &ir_builder_);
|
||||
UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
|
||||
LastThunk(), ir_emitter_context_->llvm_module());
|
||||
@ -850,9 +851,11 @@ Status IrEmitterUnnested::EmitColumnReduction(
|
||||
&ir_builder_);
|
||||
const HloInstruction* output =
|
||||
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
|
||||
llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
|
||||
llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_,
|
||||
"output_element_address");
|
||||
llvm::Value* output_address =
|
||||
GetIrArray(*output, *output)
|
||||
.EmitArrayElementAddress(
|
||||
llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_),
|
||||
&ir_builder_, "output_element_address");
|
||||
return EmitAtomicOperationForNestedComputation(
|
||||
*reducer, output_address, partial_reduction_result_address);
|
||||
};
|
||||
@ -1116,9 +1119,11 @@ Status IrEmitterUnnested::EmitRowReduction(
|
||||
"lane_id_is_zero", &ir_builder_);
|
||||
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
|
||||
&ir_builder_);
|
||||
llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
|
||||
llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_,
|
||||
"output_element_address");
|
||||
llvm::Value* output_address =
|
||||
GetIrArray(*output, *output)
|
||||
.EmitArrayElementAddress(
|
||||
llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_),
|
||||
&ir_builder_, "output_element_address");
|
||||
return EmitAtomicOperationForNestedComputation(
|
||||
*reducer, output_address, partial_reduction_result_address);
|
||||
};
|
||||
@ -1258,11 +1263,12 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
|
||||
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
|
||||
return EmitReductionToVector(
|
||||
reduce, input->shape(),
|
||||
[this, input](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_);
|
||||
[&](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*input, *reduce)
|
||||
.EmitReadArrayElement(index, &ir_builder_);
|
||||
},
|
||||
[this, init_value](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*init_value)
|
||||
[&](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*init_value, *reduce)
|
||||
.EmitReadArrayElement(index, &ir_builder_);
|
||||
},
|
||||
dimensions_to_reduce, reducer);
|
||||
@ -1426,7 +1432,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
|
||||
}
|
||||
};
|
||||
llvm_ir::IrArray operand_array(GetIrArray(*operand));
|
||||
llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
|
||||
llvm::Value* operand_data =
|
||||
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
|
||||
ir_builder_.CreateStore(operand_data, selected_value_address);
|
||||
@ -1479,9 +1485,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
ir_builder_.CreateLoad(selected_index_address_slot));
|
||||
}
|
||||
llvm::Value* source_value_address =
|
||||
GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_);
|
||||
GetIrArray(*source, *select_and_scatter)
|
||||
.EmitArrayElementAddress(source_index, &ir_builder_);
|
||||
llvm::Value* output_value_address =
|
||||
GetIrArray(*select_and_scatter)
|
||||
GetIrArray(*select_and_scatter, *select_and_scatter)
|
||||
.EmitArrayElementAddress(selected_index, &ir_builder_);
|
||||
return EmitAtomicOperationForNestedComputation(
|
||||
*select_and_scatter->scatter(), output_value_address,
|
||||
@ -1758,7 +1765,7 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
|
||||
return EmitTargetElementLoopInThunk(
|
||||
*hlo,
|
||||
[=](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*init_value)
|
||||
return GetIrArray(*init_value, *hlo)
|
||||
.EmitReadArrayElement(index, &ir_builder_);
|
||||
},
|
||||
thunk);
|
||||
@ -1859,7 +1866,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
UpdateLaunchDimensions(launch_dimensions, thunk,
|
||||
ir_emitter_context_->llvm_module());
|
||||
if (!hlo.IsMultiOutputFusion()) {
|
||||
return ParallelLoopEmitter(element_generator, GetIrArray(hlo),
|
||||
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
|
||||
launch_dimensions, &ir_builder_)
|
||||
.EmitLoop(IrName(&hlo));
|
||||
}
|
||||
@ -1867,7 +1874,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
// For multiple outputs fusion, we need to emit each operand and the root.
|
||||
std::vector<llvm_ir::IrArray> output_arrays;
|
||||
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
|
||||
output_arrays.push_back(GetIrArray(hlo, {i}));
|
||||
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
|
||||
launch_dimensions, &ir_builder_)
|
||||
@ -1878,7 +1885,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
||||
}
|
||||
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_,
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_,
|
||||
module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -75,11 +75,43 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
|
||||
std::forward_as_tuple(value_id, instruction, index, is_phi));
|
||||
CHECK(emplaced.second);
|
||||
|
||||
VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
|
||||
|
||||
return &emplaced.first->second;
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
|
||||
values_.erase(value_id);
|
||||
void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
|
||||
HloValue& value = values_.at(value_id);
|
||||
VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
|
||||
|
||||
value_ids_to_delete_.push_back(value_id);
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::DeleteMarkedValues() {
|
||||
#ifndef NDEBUG
|
||||
// Verify that no marked-for-deletion values are in any of the value sets.
|
||||
tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
|
||||
value_ids_to_delete_.end());
|
||||
for (const auto& pair : value_sets_) {
|
||||
const HloInstruction* instruction = pair.first;
|
||||
const InstructionValueSet& instruction_value_set = pair.second;
|
||||
for (const auto& index_value_set : instruction_value_set) {
|
||||
const HloValueSet& value_set = index_value_set.second;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
DCHECK(!ContainsKey(id_set, value->id()))
|
||||
<< "Value " << value->ToShortString()
|
||||
<< " marked for deletion, but still exists in value set for "
|
||||
"instruction "
|
||||
<< instruction->name();
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
for (HloValue::Id value_id : value_ids_to_delete_) {
|
||||
values_.erase(value_id);
|
||||
}
|
||||
value_ids_to_delete_.clear();
|
||||
}
|
||||
|
||||
string HloDataflowAnalysis::ToString() const {
|
||||
@ -121,6 +153,7 @@ bool HloDataflowAnalysis::Phi(
|
||||
HloInstruction* instruction,
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
|
||||
CHECK(ssa_form_);
|
||||
VLOG(4) << "Phi(" << instruction->name() << ")";
|
||||
|
||||
for (const InstructionValueSet* input : inputs) {
|
||||
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
|
||||
@ -183,7 +216,7 @@ bool HloDataflowAnalysis::Phi(
|
||||
} else if (current_value != &new_value) {
|
||||
if (current_value_defined_here) {
|
||||
// Remove the existing phi.
|
||||
DeleteHloValue(current_value->id());
|
||||
MarkValueForDeletion(current_value->id());
|
||||
}
|
||||
value_set.Clear();
|
||||
value_set.AddValue(&new_value);
|
||||
@ -193,7 +226,8 @@ bool HloDataflowAnalysis::Phi(
|
||||
// Multiple distinct values reach this point. A phi value is
|
||||
// necessary.
|
||||
CHECK_GT(input_value_ids.size(), 1);
|
||||
if (current_value == nullptr || !current_value->is_phi()) {
|
||||
if (current_value == nullptr ||
|
||||
!(current_value->is_phi() && current_value_defined_here)) {
|
||||
value_set.Clear();
|
||||
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
|
||||
changed = true;
|
||||
@ -485,11 +519,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
||||
}
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
|
||||
void HloDataflowAnalysis::Propagate() {
|
||||
std::queue<HloInstruction*> worklist;
|
||||
for (HloInstruction* instruction : instructions) {
|
||||
worklist.push(instruction);
|
||||
|
||||
for (HloComputation* computation : module_->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
worklist.push(instruction);
|
||||
}
|
||||
}
|
||||
|
||||
while (!worklist.empty()) {
|
||||
@ -662,20 +698,17 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
|
||||
|
||||
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
|
||||
dataflow_analysis->Propagate();
|
||||
|
||||
// Construct list of all instructions to initialize the worklist to propagate
|
||||
// the data flow. For efficiency sort the instruction in post order so
|
||||
// producers appear before consumers.
|
||||
std::vector<HloInstruction*> all_instructions;
|
||||
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||
for (HloInstruction* instruction :
|
||||
computation->MakeInstructionPostOrder()) {
|
||||
all_instructions.push_back(instruction);
|
||||
}
|
||||
}
|
||||
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
|
||||
// Delete all values marked for deletion.
|
||||
dataflow_analysis->DeleteMarkedValues();
|
||||
|
||||
// Add in positions to all values.
|
||||
// Gather and set all non-definition positions of all values. Value deletion
|
||||
// is rare, so just use a vector indexed by Value::Id rather than a map from
|
||||
// Value::Id to positions. There should be very few holes in the vector, and
|
||||
// lookup is faster.
|
||||
std::vector<std::vector<HloPosition>> value_positions(
|
||||
dataflow_analysis->next_value_id_);
|
||||
for (const HloComputation* computation : module->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
for (const auto& pair :
|
||||
@ -684,13 +717,18 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
const HloValueSet& value_set = pair.second;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
if (value->defining_instruction() != instruction) {
|
||||
dataflow_analysis->GetValue(value->id())
|
||||
.AddPosition(instruction, index);
|
||||
value_positions[value->id()].push_back(
|
||||
HloPosition{instruction, index});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& pair : dataflow_analysis->values_) {
|
||||
HloValue::Id value_id = pair.first;
|
||||
HloValue& value = pair.second;
|
||||
value.SetPositionsAndComputeUses(value_positions[value_id]);
|
||||
}
|
||||
|
||||
// Construct vector of values.
|
||||
dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
|
||||
|
@ -126,13 +126,16 @@ class HloDataflowAnalysis {
|
||||
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
|
||||
bool is_phi = false);
|
||||
|
||||
// Delete the HloValue with the given ID.
|
||||
void DeleteHloValue(HloValue::Id value_id);
|
||||
// Mark the HloValue with the given ID for deletion.
|
||||
void MarkValueForDeletion(HloValue::Id value_id);
|
||||
|
||||
// Delete all HloValues marked for deletion. Should be called after
|
||||
// propagation is complete.
|
||||
void DeleteMarkedValues();
|
||||
|
||||
// Constructs and initializes the InstructionValueSets of all instructions to
|
||||
// contain exactly the HloValues defined by each instruction. These values can
|
||||
// then propagated throughout the HLO graph by calling
|
||||
// UpdateInstructionsAndPropagate.
|
||||
// then propagated throughout the HLO graph by calling Propagate.
|
||||
Status InitializeInstructionValueSets();
|
||||
|
||||
// Updates the value set of the given instruction based on the values flowing
|
||||
@ -152,10 +155,8 @@ class HloDataflowAnalysis {
|
||||
bool UpdateTupleValueSet(HloInstruction* tuple);
|
||||
bool UpdateWhileValueSet(HloInstruction* xla_while);
|
||||
|
||||
// Update the value sets of the given instructions and propagate the
|
||||
// changes to fixed point.
|
||||
void UpdateInstructionsAndPropagate(
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
|
||||
// Propagate the dataflow through the module.
|
||||
void Propagate();
|
||||
|
||||
// Return the result of the SSA Phi function applied to the given inputs at
|
||||
// the given instruction. If skip_top_level is true, then the top level of the
|
||||
@ -191,6 +192,11 @@ class HloDataflowAnalysis {
|
||||
// A map from instruction to InstructionValueSet.
|
||||
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
|
||||
|
||||
// Values marked for deletion during construction. We don't delete them
|
||||
// immediately because references to them may remain in ValueSets temporarily
|
||||
// during propagation. After construction, these values are deleted.
|
||||
std::vector<HloValue::Id> value_ids_to_delete_;
|
||||
|
||||
// A vector containing all HloValues sorted by HloValue::Id.
|
||||
std::vector<const HloValue*> values_vector_;
|
||||
|
||||
|
@ -211,10 +211,10 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
|
||||
HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}},
|
||||
HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}},
|
||||
HloPosition{gte_out, {}}));
|
||||
// Constant values should have no uses though one is live out. The positions
|
||||
// where they appear as operands are on instructions which do not use the
|
||||
// values (eg, Tuple).
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
|
||||
// Constant values should have only a single use, which is the root of the
|
||||
// computation.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(),
|
||||
UnorderedElementsAre(HloUse{gte_out, 0, {0}}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
|
||||
|
||||
// The top-level tuple values are used in GTE instructions.
|
||||
@ -274,12 +274,11 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add));
|
||||
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}));
|
||||
UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 1, {}}));
|
||||
UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}}));
|
||||
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
|
||||
@ -323,18 +322,17 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
|
||||
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}));
|
||||
UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}},
|
||||
HloUse{add, 0, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 1, {}}));
|
||||
UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}},
|
||||
HloUse{add, 1, {}}));
|
||||
// The Add from the subcomputation is used as both operands of the Subtract.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(),
|
||||
UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}}));
|
||||
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
|
||||
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation());
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
|
||||
@ -408,7 +406,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
|
||||
auto outer_param1 = outer_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
|
||||
// Swizzle parameters.
|
||||
outer_builder.AddInstruction(HloInstruction::CreateCall(
|
||||
auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
|
||||
scalar_shape_, {outer_param1, outer_param0}, inner_computation));
|
||||
HloComputation* outer_computation =
|
||||
module_->AddEmbeddedComputation(outer_builder.Build());
|
||||
@ -418,7 +416,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
builder.AddInstruction(HloInstruction::CreateCall(
|
||||
auto call = builder.AddInstruction(HloInstruction::CreateCall(
|
||||
scalar_shape_, {constant1, constant2}, outer_computation));
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
@ -431,10 +429,14 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
|
||||
|
||||
// Verify that the uses of the constants are properly swizzled by parameter
|
||||
// permutation in nested_call.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}));
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
|
||||
HloUse{add, 1, {}}));
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
|
||||
HloUse{add, 0, {}}));
|
||||
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
|
||||
}
|
||||
@ -469,7 +471,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
|
||||
auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
|
||||
body_builder.AddInstruction(
|
||||
auto body_root = body_builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({body_element_0, add}));
|
||||
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
@ -496,8 +498,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
|
||||
bool ssa_form = GetParam();
|
||||
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
|
||||
EXPECT_TRUE(
|
||||
analysis.GetValueDefinedAt(cond_constant).live_out_of_computation());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
|
||||
|
||||
if (ssa_form) {
|
||||
@ -517,14 +517,14 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
|
||||
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}},
|
||||
HloUse{xla_while, 0, {0}}));
|
||||
|
||||
// Constant1 passes through the body and out of the module.
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
|
||||
.live_out_of_module());
|
||||
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
|
||||
} else {
|
||||
// While instruction and subcomputation parameters should not define values
|
||||
@ -538,7 +538,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
|
||||
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
|
||||
}
|
||||
}
|
||||
|
||||
@ -915,9 +914,11 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
|
||||
HloUse{select12, 1, {}}));
|
||||
|
||||
// The two constant values just pass through the Selects and are not
|
||||
// used. They are live out however.
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
|
||||
// used except at the root. They are live out however.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{select1234, 1, {0}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{select1234, 1, {0}}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
|
||||
}
|
||||
@ -1318,7 +1319,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
|
||||
|
||||
auto entry = module_->AddEntryComputation(builder.Build());
|
||||
bool ssa_form = GetParam();
|
||||
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
RunAnalysis(ssa_form);
|
||||
|
||||
SequentialHloOrdering::HloModuleSequence sequence;
|
||||
sequence.insert({entry, {param, xla_while}});
|
||||
@ -1329,12 +1330,6 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
|
||||
|
||||
SequentialHloOrdering ordering(module_.get(), sequence);
|
||||
|
||||
// 'add' is the body root even though later instructions follow in the order
|
||||
// like 'dead_negate'. Only 'add' should be live out of the computation.
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
|
||||
EXPECT_FALSE(
|
||||
analysis.GetValueDefinedAt(dead_negate).live_out_of_computation());
|
||||
|
||||
// 'add' is live out of the body and will interfere with an later instructions
|
||||
// such as 'dead_constant' and 'dead_negate'.
|
||||
EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
HloToProfileIndex::HloToProfileIndex(const HloModule& module) {
|
||||
HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
|
||||
size_t current_profile_index = 0;
|
||||
for (xla::HloComputation* computation : module.MakeComputationPostOrder()) {
|
||||
InsertOrDie(&computation_to_profile_idx_, computation,
|
||||
@ -41,24 +41,24 @@ HloToProfileIndex::HloToProfileIndex(const HloModule& module) {
|
||||
}
|
||||
|
||||
static HloProfilePrinter CreateOwnedHloProfilePrinter(
|
||||
const HloToProfileIndex& hlo_to_profile_index,
|
||||
const HloProfileIndexMap& hlo_profile_index_map,
|
||||
const HloCostAnalysis& cost_analysis) {
|
||||
using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
|
||||
using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo;
|
||||
|
||||
HloComputationInfo* computation_infos =
|
||||
new HloComputationInfo[hlo_to_profile_index.computation_count()];
|
||||
new HloComputationInfo[hlo_profile_index_map.computation_count()];
|
||||
|
||||
// There are two "indices" in play here. The first one is the index of the
|
||||
// HloComputationInfo or HloInstructionInfo in the array that contains said
|
||||
// HloComputationInfo or HloInstructionInfo. The second index is the index of
|
||||
// the HloComputationInfo or HloInstructionInfo in the profile counters array,
|
||||
// as decided by hlo_to_profile_index. The latter index is always referred to
|
||||
// as "profile_index".
|
||||
// as decided by hlo_profile_index_map. The latter index is always referred
|
||||
// to as "profile_index".
|
||||
|
||||
size_t computation_index_in_static_data = 0;
|
||||
size_t max_profile_index = hlo_to_profile_index.total_count();
|
||||
for (const auto& pair : hlo_to_profile_index.computation_to_profile_idx()) {
|
||||
size_t max_profile_index = hlo_profile_index_map.total_count();
|
||||
for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) {
|
||||
CHECK_LT(pair.second, max_profile_index);
|
||||
const HloComputation* computation = pair.first;
|
||||
size_t current_computation_index = computation_index_in_static_data++;
|
||||
@ -85,7 +85,7 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
|
||||
instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo);
|
||||
instruction_info->seconds = cost_analysis.seconds(*hlo);
|
||||
instruction_info->profile_index =
|
||||
hlo_to_profile_index.GetProfileIndexFor(*hlo);
|
||||
hlo_profile_index_map.GetProfileIndexFor(*hlo);
|
||||
CHECK_LT(instruction_info->profile_index, max_profile_index);
|
||||
}
|
||||
}
|
||||
@ -109,26 +109,26 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
|
||||
};
|
||||
|
||||
return HloProfilePrinter(computation_infos,
|
||||
hlo_to_profile_index.computation_count(), deleter);
|
||||
hlo_profile_index_map.computation_count(), deleter);
|
||||
}
|
||||
|
||||
HloExecutionProfile::HloExecutionProfile(const HloModule& module,
|
||||
const HloCostAnalysis& cost_analysis)
|
||||
: hlo_to_profile_index_(module),
|
||||
: hlo_profile_index_map_(module),
|
||||
hlo_profile_printer_(
|
||||
CreateOwnedHloProfilePrinter(hlo_to_profile_index_, cost_analysis)),
|
||||
CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)),
|
||||
profile_counters_(
|
||||
/*count*/ hlo_to_profile_index_.total_count(),
|
||||
/*count*/ hlo_profile_index_map_.total_count(),
|
||||
/*value*/ 0) {}
|
||||
|
||||
void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo,
|
||||
uint64 cycles_taken) {
|
||||
profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(*hlo)] =
|
||||
profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] =
|
||||
cycles_taken;
|
||||
}
|
||||
|
||||
uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {
|
||||
return profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(hlo)];
|
||||
return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)];
|
||||
}
|
||||
|
||||
string HloExecutionProfile::ToString(
|
||||
|
@ -29,18 +29,18 @@ namespace xla {
|
||||
|
||||
class HloInstruction;
|
||||
|
||||
// Maps all HloInstructions and HloComputions in an HloModule to integers.
|
||||
// These integers form the contiguous range [0, GetTotalCount()).
|
||||
class HloToProfileIndex {
|
||||
// Maps all HloInstructions and HloComputations in an HloModule to integers.
|
||||
// These integers form the contiguous range [0, total_count()).
|
||||
class HloProfileIndexMap {
|
||||
public:
|
||||
// Scans `module` to populate this instance of HloToProfileIndex.
|
||||
explicit HloToProfileIndex(const HloModule& module);
|
||||
// Scans `module` to populate this instance of HloProfileIndexMap.
|
||||
explicit HloProfileIndexMap(const HloModule& module);
|
||||
|
||||
HloToProfileIndex(const HloToProfileIndex&) = default;
|
||||
HloToProfileIndex(HloToProfileIndex&&) = default;
|
||||
HloProfileIndexMap(const HloProfileIndexMap&) = default;
|
||||
HloProfileIndexMap(HloProfileIndexMap&&) = default;
|
||||
|
||||
HloToProfileIndex& operator=(const HloToProfileIndex&) = default;
|
||||
HloToProfileIndex& operator=(HloToProfileIndex&&) = default;
|
||||
HloProfileIndexMap& operator=(const HloProfileIndexMap&) = default;
|
||||
HloProfileIndexMap& operator=(HloProfileIndexMap&&) = default;
|
||||
|
||||
size_t GetProfileIndexFor(const HloInstruction& instruction) const {
|
||||
return FindOrDie(instruction_to_profile_idx(), &instruction);
|
||||
@ -97,14 +97,14 @@ class HloExecutionProfile {
|
||||
|
||||
// Return the number of cycles this computation took to execute.
|
||||
uint64 total_cycles_executed(const HloComputation& computation) const {
|
||||
return profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(
|
||||
return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(
|
||||
computation)];
|
||||
}
|
||||
|
||||
// Record how many cycles a computation took to execute.
|
||||
void set_total_cycles_executed(const HloComputation& computation,
|
||||
uint64 total_cycles_executed) {
|
||||
profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(computation)] =
|
||||
profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(computation)] =
|
||||
total_cycles_executed;
|
||||
}
|
||||
|
||||
@ -117,9 +117,9 @@ class HloExecutionProfile {
|
||||
string ToString(const DeviceDescription& device_description) const;
|
||||
|
||||
private:
|
||||
// hlo_to_profile_index_ maps an Hlo entity (computation or instruction) to an
|
||||
// index in profile_counters_.
|
||||
HloToProfileIndex hlo_to_profile_index_;
|
||||
// hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to
|
||||
// an index in profile_counters_.
|
||||
HloProfileIndexMap hlo_profile_index_map_;
|
||||
|
||||
// Used to print profile_counters_ in a human readable form.
|
||||
HloProfilePrinter hlo_profile_printer_;
|
||||
|
@ -970,6 +970,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kCrossReplicaSum:
|
||||
return kBrown;
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kWhile:
|
||||
case HloOpcode::kCall:
|
||||
@ -1003,7 +1004,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
|
||||
}
|
||||
string extended_opcode =
|
||||
StrCat(HloOpcodeString(instr->opcode()),
|
||||
instr->opcode() == HloOpcode::kFusion
|
||||
instr->opcode() != HloOpcode::kFusion
|
||||
? ""
|
||||
: StrCat(":", xla::ToString(instr->fusion_kind())));
|
||||
// If the name does not contain the opcode, render both.
|
||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
using tensorflow::str_util::CEscape;
|
||||
using ::tensorflow::str_util::Join;
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
@ -1209,6 +1210,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
new_operands[2], new_operands[3],
|
||||
new_operands[4], epsilon(), feature_index());
|
||||
break;
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kRecv:
|
||||
case HloOpcode::kRecvDone:
|
||||
case HloOpcode::kSend:
|
||||
@ -1602,6 +1604,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
return dimensions() == other.dimensions();
|
||||
|
||||
// These opcodes are not yet supported.
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kSort:
|
||||
@ -1965,6 +1968,13 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
|
||||
}),
|
||||
"}"));
|
||||
}
|
||||
if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) {
|
||||
extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\""));
|
||||
}
|
||||
if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) {
|
||||
extra.push_back(
|
||||
StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
|
||||
}
|
||||
return extra;
|
||||
}
|
||||
|
||||
@ -2347,6 +2357,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleSendDone(this);
|
||||
|
||||
// These opcodes are not handled here.
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kTrace:
|
||||
break;
|
||||
}
|
||||
@ -2920,7 +2931,6 @@ string PaddingConfigToString(const PaddingConfig& padding) {
|
||||
|
||||
string OpMetadataToString(const OpMetadata& metadata) {
|
||||
std::vector<string> result;
|
||||
using tensorflow::str_util::CEscape;
|
||||
if (!metadata.op_type().empty()) {
|
||||
result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
|
||||
}
|
||||
|
@ -58,6 +58,7 @@ namespace xla {
|
||||
V(kClamp, "clamp") \
|
||||
V(kComplex, "complex") \
|
||||
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
|
||||
V(kConditional, "conditional") \
|
||||
V(kConstant, "constant") \
|
||||
V(kConvert, "convert") \
|
||||
V(kConvolution, "convolution") \
|
||||
|
@ -173,6 +173,19 @@ bool HloOrdering::UseIsBeforeValueDefinition(
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// The use at a call occurs before values that are defined in the called
|
||||
// computation.
|
||||
if (use.instruction->opcode() == HloOpcode::kCall) {
|
||||
const HloInstruction* call = use.instruction;
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
call->to_apply())) {
|
||||
VLOG(4) << " use is call " << use.instruction->name()
|
||||
<< " and def is in called computation";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(4) << " use is not before value";
|
||||
return false;
|
||||
}
|
||||
@ -187,23 +200,6 @@ bool HloOrdering::LiveRangeStrictlyBefore(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Live-out values from the module can never have ranges strictly before any
|
||||
// other value.
|
||||
if (a.live_out_of_module()) {
|
||||
VLOG(4) << "a is live out of module";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Live-out values of computations can never have ranges strictly before any
|
||||
// other value in the computation (including values nested in
|
||||
// subcomputations).
|
||||
if (a.live_out_of_computation() &&
|
||||
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
|
||||
a.defining_instruction()->parent())) {
|
||||
VLOG(4) << "a is live out of computation containing b";
|
||||
return false;
|
||||
}
|
||||
|
||||
// All uses of 'a' must be before 'b' is defined.
|
||||
for (const HloUse& use : a.uses()) {
|
||||
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
|
||||
|
@ -71,7 +71,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
|
||||
const ShapeIndex& index, bool is_phi)
|
||||
: id_(id), is_phi_(is_phi) {
|
||||
// The defining position is always the first element in the positions_ vector.
|
||||
AddPosition(instruction, index);
|
||||
positions_.push_back(HloPosition{instruction, index});
|
||||
}
|
||||
|
||||
bool HloValue::operator==(const HloValue& other) const {
|
||||
@ -130,18 +130,14 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
|
||||
CHECK_LE(operand_number, 2);
|
||||
return operand_number == 0 || index.empty();
|
||||
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kTuple:
|
||||
// These instructions always pass through their operands transparently.
|
||||
return false;
|
||||
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kWhile:
|
||||
// Though the while instructions passes through its operands, we return
|
||||
// true because in SSA form there may be a Phi at the parameter of the
|
||||
// while which is considered a use of its incoming value because the Phi
|
||||
// input values are not passed through into the body computation. Because
|
||||
// this function is used in both SSA and non-SSA forms of the analysis
|
||||
// conservatively return true.
|
||||
// Although call and while instructions pass through their operands, they
|
||||
// are considered uses.
|
||||
return true;
|
||||
|
||||
default:
|
||||
@ -151,103 +147,58 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
|
||||
|
||||
} // namespace
|
||||
|
||||
void HloValue::AddPosition(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
HloPosition new_position{instruction, index};
|
||||
void HloValue::SetPositionsAndComputeUses(
|
||||
tensorflow::gtl::ArraySlice<HloPosition> positions) {
|
||||
CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
|
||||
|
||||
// The new position must not already exist in positions_.
|
||||
// The positions must be unique and should not contain the defining position
|
||||
// as this is added at construction time.
|
||||
for (const HloPosition& position_a : positions) {
|
||||
DCHECK_NE(position_a, defining_position());
|
||||
for (const HloPosition& position_b : positions) {
|
||||
if (&position_a != &position_b) {
|
||||
DCHECK_NE(position_a, position_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
positions_.insert(positions_.end(), positions.begin(), positions.end());
|
||||
|
||||
// Gather the computation roots at which this value appears.
|
||||
tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
|
||||
for (const HloPosition& position : positions_) {
|
||||
DCHECK_NE(position, new_position);
|
||||
}
|
||||
|
||||
positions_.push_back(std::move(new_position));
|
||||
|
||||
// Update uses.
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
for (int64 operand_number : user->OperandIndices(instruction)) {
|
||||
if (MayUseOperandValue(operand_number, index, user)) {
|
||||
HloUse new_use{user, operand_number, index};
|
||||
|
||||
// The new use must not already exist in uses_.
|
||||
for (const HloUse& use : uses_) {
|
||||
DCHECK_NE(use, new_use);
|
||||
}
|
||||
|
||||
uses_.push_back(std::move(new_use));
|
||||
}
|
||||
if (position.instruction ==
|
||||
position.instruction->parent()->root_instruction()) {
|
||||
root_positions.insert(position.instruction);
|
||||
}
|
||||
}
|
||||
|
||||
// Update liveout status of this HloValue.
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
live_out_of_module_ = true;
|
||||
}
|
||||
|
||||
if (instruction == instruction->parent()->root_instruction()) {
|
||||
live_out_of_computation_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void HloValue::RemovePosition(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
// The defining position cannot be removed.
|
||||
CHECK(!(instruction == defining_instruction() && index == defining_index()));
|
||||
|
||||
int64 size_before = positions_.size();
|
||||
positions_.erase(
|
||||
std::remove_if(positions_.begin(), positions_.end(),
|
||||
[instruction, &index](const HloPosition& position) {
|
||||
return position.instruction == instruction &&
|
||||
position.index == index;
|
||||
}),
|
||||
positions_.end());
|
||||
// Only a single position should have been removed.
|
||||
CHECK_EQ(positions_.size(), size_before - 1);
|
||||
|
||||
// Update uses which referred to this position.
|
||||
uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
|
||||
[instruction, &index](const HloUse& use) {
|
||||
return use.instruction->operand(
|
||||
use.operand_number) == instruction &&
|
||||
use.operand_index == index;
|
||||
}),
|
||||
uses_.end());
|
||||
|
||||
// Returns whether this value is contained in the given instruction's output.
|
||||
auto is_contained_in = [this](const HloInstruction* instruction) {
|
||||
for (const HloPosition& position : positions()) {
|
||||
if (position.instruction == instruction) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
// Value has been removed from a position in the entry root instruction.
|
||||
live_out_of_module_ =
|
||||
is_contained_in(module.entry_computation()->root_instruction());
|
||||
}
|
||||
if (instruction == defining_instruction()->parent()->root_instruction()) {
|
||||
// Value has been removed from the root of the computation the value has
|
||||
// been defined in.
|
||||
live_out_of_computation_ =
|
||||
is_contained_in(defining_instruction()->parent()->root_instruction());
|
||||
}
|
||||
}
|
||||
|
||||
void HloValue::RecomputeUses() {
|
||||
uses_.clear();
|
||||
for (const HloPosition& position : positions()) {
|
||||
// Build vector of HloUses for the value.
|
||||
for (const HloPosition& position : positions_) {
|
||||
for (HloInstruction* user : position.instruction->users()) {
|
||||
for (int64 operand_number : user->OperandIndices(position.instruction)) {
|
||||
if (MayUseOperandValue(operand_number, position.index, user)) {
|
||||
uses_.push_back(HloUse{user, operand_number, position.index});
|
||||
// Root instructions of computations are considered to be uses whether
|
||||
// or not the root instruction itself actually uses the value.
|
||||
if (MayUseOperandValue(operand_number, position.index, user) ||
|
||||
ContainsKey(root_positions, user)) {
|
||||
HloUse new_use{user, operand_number, position.index};
|
||||
|
||||
// The new use must not already exist in uses_.
|
||||
for (const HloUse& use : uses_) {
|
||||
DCHECK_NE(use, new_use);
|
||||
}
|
||||
|
||||
uses_.push_back(std::move(new_use));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update liveout status of this HloValue.
|
||||
const HloModule& module = *position.instruction->parent()->parent();
|
||||
if (position.instruction ==
|
||||
module.entry_computation()->root_instruction()) {
|
||||
live_out_of_module_ = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -121,6 +121,12 @@ class HloValue {
|
||||
HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
|
||||
bool is_phi = false);
|
||||
|
||||
// Sets the positions in the module at which the HloValue appears. Updates
|
||||
// uses. Should be called once and only once. The defining position should not
|
||||
// be included in 'positions' as this is set at construction time.
|
||||
void SetPositionsAndComputeUses(
|
||||
tensorflow::gtl::ArraySlice<HloPosition> positions);
|
||||
|
||||
// Return a unique identifier for this HloValue. This value is used for stable
|
||||
// sorting and iteration
|
||||
Id id() const { return id_; }
|
||||
@ -143,28 +149,15 @@ class HloValue {
|
||||
// Return the shape of this HloValue.
|
||||
const Shape& shape() const { return defining_position().shape(); }
|
||||
|
||||
// Add or remove a position at which the HloValue appears. The definition
|
||||
// position can not be removed. The uses of the HloValue are updated.
|
||||
void AddPosition(HloInstruction* instruction, const ShapeIndex& index);
|
||||
void RemovePosition(HloInstruction* instruction, const ShapeIndex& index);
|
||||
|
||||
// Remove all positions except the defining position. Updates uses.
|
||||
void ClearPositions();
|
||||
|
||||
// Return all positions of the HloValue in the module.
|
||||
const std::vector<HloPosition>& positions() const { return positions_; }
|
||||
|
||||
// Return all uses of the HloValue.
|
||||
const std::vector<HloUse>& uses() const { return uses_; }
|
||||
|
||||
void RecomputeUses();
|
||||
|
||||
// Get whether this HloValue is live out of the module.
|
||||
bool live_out_of_module() const { return live_out_of_module_; }
|
||||
|
||||
// Get whether this HloValue is live out of the computation it is defined in.
|
||||
bool live_out_of_computation() const { return live_out_of_computation_; }
|
||||
|
||||
bool operator==(const HloValue& other) const;
|
||||
bool operator!=(const HloValue& other) const;
|
||||
|
||||
|
@ -92,6 +92,7 @@ namespace xla {
|
||||
case HloOpcode::kBatchNormInference:
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kCrossReplicaSum:
|
||||
case HloOpcode::kCustomCall:
|
||||
|
@ -83,7 +83,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
||||
if (std::find(parameter_instructions.begin(),
|
||||
parameter_instructions.end(),
|
||||
&hlo) != parameter_instructions.end()) {
|
||||
array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{}));
|
||||
array->MarkInvariantOverWholeProgram(context_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -256,10 +256,10 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
|
||||
llvm::Instruction* instruction) const {
|
||||
CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
|
||||
llvm::isa<llvm::StoreInst>(instruction));
|
||||
CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
|
||||
<< "Trying to create a store to an invariant IRArray.";
|
||||
|
||||
for (const auto& kind_md_pair : metadata_) {
|
||||
CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load ||
|
||||
llvm::isa<llvm::LoadInst>(instruction));
|
||||
instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
|
||||
}
|
||||
}
|
||||
|
@ -229,9 +229,33 @@ class IrArray {
|
||||
AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
|
||||
}
|
||||
|
||||
void AddInvariantLoad(llvm::MDNode* invariant_load) {
|
||||
CHECK_NE(invariant_load, nullptr);
|
||||
AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load);
|
||||
// Promises LLVM that the data pointed to by this IrArray never changes after
|
||||
// it's first loaded.
|
||||
//
|
||||
// The temporal scope of this promise is the "whole program" from LLVM's point
|
||||
// of view, but how this translates to HLOs differs between backends.
|
||||
//
|
||||
// In the single-threaded CPU backend, we emit one function that
|
||||
// runs all the HLOs in sequence, so the whole program is the whole HLO
|
||||
// module.
|
||||
//
|
||||
// In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
|
||||
// in the entry computation). From LLVM's perspective, launching a new kernel
|
||||
// is like launching a new program, and so the whole program is one top-level
|
||||
// HLO. Since the scope of the promise is smaller than in the CPU backend, we
|
||||
// can mark more things as invariant in the GPU backend.
|
||||
//
|
||||
// Marking loads as invariant is particularly helpful on GPUs because
|
||||
// invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
|
||||
// __ldg intrinsic). These loads use a special cache, and can be
|
||||
// significantly faster than regular loads.
|
||||
void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
|
||||
if (is_invariant_) {
|
||||
return;
|
||||
}
|
||||
is_invariant_ = true;
|
||||
AddMetadata(llvm::LLVMContext::MD_invariant_load,
|
||||
llvm::MDNode::get(*context, {}));
|
||||
}
|
||||
|
||||
const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
|
||||
@ -261,6 +285,8 @@ class IrArray {
|
||||
// loads/stores for this array. The keys are the metadata kinds and the
|
||||
// values are the metadata nodes.
|
||||
std::map<int, llvm::MDNode*> metadata_;
|
||||
|
||||
bool is_invariant_ = false;
|
||||
};
|
||||
|
||||
} // namespace llvm_ir
|
||||
|
@ -63,6 +63,14 @@ void ShapedBuffer::clear() {
|
||||
}
|
||||
}
|
||||
|
||||
void ShapedBuffer::AddBufferAtIndex(
|
||||
const perftools::gputools::DeviceMemoryBase& buffer,
|
||||
const ShapeIndex& shape_index) {
|
||||
*mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) =
|
||||
buffers().size();
|
||||
mutable_buffers()->push_back(buffer);
|
||||
}
|
||||
|
||||
const se::DeviceMemoryBase& ShapedBuffer::buffer(
|
||||
const ShapeIndex& index) const {
|
||||
return buffers_[shape_index_to_buffer_entry_.element(index)];
|
||||
|
@ -75,6 +75,10 @@ class ShapedBuffer {
|
||||
// Set all device memory pointers in the object to null.
|
||||
void clear();
|
||||
|
||||
// Adds a new buffer at the given shape index.
|
||||
void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer,
|
||||
const ShapeIndex& shape_index);
|
||||
|
||||
protected:
|
||||
// The shape of the device buffer with layout.
|
||||
const Shape shape_;
|
||||
|
@ -28,12 +28,9 @@ limitations under the License.
|
||||
namespace se = ::perftools::gputools;
|
||||
|
||||
namespace xla {
|
||||
|
||||
/* static */ tensorflow::mutex*
|
||||
TransferManager::platform_transfer_manager_mutex() {
|
||||
static tensorflow::mutex* m = new tensorflow::mutex;
|
||||
return m;
|
||||
}
|
||||
/* static */ tensorflow::mutex
|
||||
TransferManager::platform_transfer_manager_mutex_(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
/* static */ std::map<perftools::gputools::Platform::Id,
|
||||
TransferManager::State>*
|
||||
@ -47,7 +44,7 @@ TransferManager::GetPlatformTransferManagers() {
|
||||
se::Platform::Id platform_id,
|
||||
TransferManagerCreationFunction creation_function) {
|
||||
tensorflow::mutex_lock lock(
|
||||
*TransferManager::platform_transfer_manager_mutex());
|
||||
TransferManager::platform_transfer_manager_mutex_);
|
||||
auto* managers = GetPlatformTransferManagers();
|
||||
CHECK(managers->find(platform_id) == managers->end());
|
||||
(*managers)[platform_id].creation_function = creation_function;
|
||||
@ -56,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() {
|
||||
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
|
||||
const se::Platform* platform) {
|
||||
tensorflow::mutex_lock lock(
|
||||
*TransferManager::platform_transfer_manager_mutex());
|
||||
TransferManager::platform_transfer_manager_mutex_);
|
||||
auto* managers = GetPlatformTransferManagers();
|
||||
|
||||
auto it = managers->find(platform->id());
|
||||
|
@ -158,11 +158,8 @@ class TransferManager {
|
||||
const perftools::gputools::Platform* platform);
|
||||
|
||||
private:
|
||||
// Routine that returns the mutex that guards the
|
||||
// platform-to-transfer manager map. Done as a routine to
|
||||
// ensure correct initialization ordering, since RegisterTransferManager
|
||||
// can be called during program initialization time.
|
||||
static tensorflow::mutex* platform_transfer_manager_mutex();
|
||||
// The mutex that guards the platform-to-transfer manager map.
|
||||
static tensorflow::mutex platform_transfer_manager_mutex_;
|
||||
|
||||
// State kept for each kind of TransferManager. Registration functions
|
||||
// set up creation_function, and then we use that to lazily create
|
||||
|
@ -592,10 +592,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
return sizeof(uint32);
|
||||
case U64:
|
||||
return sizeof(uint64);
|
||||
case F16:
|
||||
return sizeof(float) / 2;
|
||||
case BF16:
|
||||
return sizeof(float) / 2;
|
||||
case F16:
|
||||
return sizeof(float) / 2;
|
||||
case F32:
|
||||
return sizeof(float);
|
||||
case F64:
|
||||
|
@ -69,7 +69,10 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_headers_lib",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -248,5 +248,6 @@ def generate_backend_test_macros(backends=[]):
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
])
|
||||
|
@ -21,12 +21,13 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
|
||||
// disabled.
|
||||
// disabled - a sequence of regexps.
|
||||
using ManifestT = std::unordered_map<string, std::vector<string>>;
|
||||
|
||||
ManifestT ReadManifest() {
|
||||
@ -66,9 +67,6 @@ ManifestT ReadManifest() {
|
||||
|
||||
string PrependDisabledIfIndicated(const string& test_case_name,
|
||||
const string& test_name) {
|
||||
// TODO(leary): this code reads the manifest for every test case instantiated
|
||||
// in every file. Consider switching to a singleton or using a compile-time
|
||||
// genrule instead.
|
||||
ManifestT manifest = ReadManifest();
|
||||
|
||||
// First try full match: test_case_name.test_name
|
||||
@ -83,11 +81,13 @@ string PrependDisabledIfIndicated(const string& test_case_name,
|
||||
}
|
||||
}
|
||||
|
||||
// Expect a full match vs. one of the platform regexps to disable the test.
|
||||
const std::vector<string>& disabled_platforms = it->second;
|
||||
string platform_string = XLA_PLATFORM;
|
||||
if (std::find(disabled_platforms.begin(), disabled_platforms.end(),
|
||||
platform_string) != disabled_platforms.end()) {
|
||||
return "DISABLED_" + test_name;
|
||||
for (const auto& s : disabled_platforms) {
|
||||
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
|
||||
return "DISABLED_" + test_name;
|
||||
}
|
||||
}
|
||||
|
||||
// We didn't hit in the disabled manifest entries, so don't disable it.
|
||||
|
@ -66,8 +66,10 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Reads a disabled manifest file (and retains it as a singleton) to resolve
|
||||
// whether test cases should be disabled on a particular platform.
|
||||
// Reads a disabled manifest file to resolve whether test cases should be
|
||||
// disabled on a particular platform. For a test that should be disabled,
|
||||
// returns DISABLED_ prepended to its name; otherwise returns the test name
|
||||
// unmodified.
|
||||
string PrependDisabledIfIndicated(const string& test_case_name,
|
||||
const string& test_name);
|
||||
|
||||
|
@ -14,8 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -46,6 +47,44 @@ void PopulateWithRandomIntegralData(Literal* literal) {
|
||||
}));
|
||||
}
|
||||
|
||||
bool LooksLikeSum(const HloInstruction& instruction) {
|
||||
return instruction.opcode() == HloOpcode::kAdd &&
|
||||
instruction.operand(0)->opcode() == HloOpcode::kParameter &&
|
||||
instruction.operand(1)->opcode() == HloOpcode::kParameter &&
|
||||
instruction.operand(0) != instruction.operand(1);
|
||||
}
|
||||
|
||||
// Given an instruction and operand number, replace the given operand with
|
||||
// a Literal Constant Zero. Handle the case of a fusion instruction by
|
||||
// replacing the fusion's parent's parameter with a Literal Constant Zero,
|
||||
// unless the fusion's parent is itself a fusion.
|
||||
Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction,
|
||||
const int64 operand_number) {
|
||||
CHECK_LT(operand_number, instruction->operand_count());
|
||||
if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
HloComputation* const computation = instruction->parent();
|
||||
std::unique_ptr<HloInstruction> zero = HloInstruction::CreateConstant(
|
||||
MakeUnique<Literal>(Literal::Zero(instruction->shape().element_type())));
|
||||
|
||||
if (computation->IsFusionComputation()) {
|
||||
HloInstruction* const fusion_instruction = computation->FusionInstruction();
|
||||
if (fusion_instruction->IsFused()) {
|
||||
return Unimplemented(
|
||||
"Unable to replace fused parameter of fusion instruction");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith(
|
||||
instruction->operand(operand_number)->parameter_number(),
|
||||
fusion_instruction->parent()->AddInstruction(std::move(zero))));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(
|
||||
operand_number, computation->AddInstruction(std::move(zero))));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
|
||||
@ -117,4 +156,32 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
|
||||
return std::move(arguments);
|
||||
}
|
||||
|
||||
Status ReplaceInitsWithConstants(HloModule* const module) {
|
||||
for (HloComputation* const computation : module->computations()) {
|
||||
for (HloInstruction* const instruction : computation->instructions()) {
|
||||
const HloOpcode opcode = instruction->opcode();
|
||||
if ((opcode == HloOpcode::kReduce ||
|
||||
opcode == HloOpcode::kReduceWindow) &&
|
||||
LooksLikeSum(*instruction->to_apply()->root_instruction())) {
|
||||
TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1));
|
||||
} else if (opcode == HloOpcode::kSelectAndScatter &&
|
||||
LooksLikeSum(*instruction->scatter()->root_instruction())) {
|
||||
TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VerifyHloModule(const perftools::gputools::Platform& platform,
|
||||
HloModule* const module) {
|
||||
return HloVerifier(
|
||||
std::bind(
|
||||
&TransferManager::GetByteSizeRequirement,
|
||||
TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(),
|
||||
std::placeholders::_1))
|
||||
.Run(module)
|
||||
.status();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -62,6 +63,16 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
|
||||
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
|
||||
const HloModule& module);
|
||||
|
||||
// Reduce, ReduceWindow, and SelectAndScatter ops that sum require their
|
||||
// init_value to be replaced with a constant zero when testing, otherwise we
|
||||
// may generate a bad init_value when looking at the op in isolation.
|
||||
Status ReplaceInitsWithConstants(HloModule* const module);
|
||||
|
||||
// Check that a given module satisfies various constraints before trying to
|
||||
// execute it.
|
||||
Status VerifyHloModule(const perftools::gputools::Platform& platform,
|
||||
HloModule* const module);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
|
||||
|
@ -776,11 +776,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
shape, *fusion_kind, operands, *fusion_computation));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kInfeed: {
|
||||
optional<string> config;
|
||||
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/0) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
HloInstruction::CreateInfeed(shape, config ? *config : ""));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kOutfeed: {
|
||||
optional<string> config;
|
||||
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
|
||||
shape, operands[0], config ? *config : ""));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kTrace:
|
||||
return TokenError(StrCat("parsing not yet implemented for op: ",
|
||||
HloOpcodeString(opcode)));
|
||||
@ -805,7 +826,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
instruction->set_metadata(*metadata);
|
||||
}
|
||||
return AddInstruction(name, instruction);
|
||||
}
|
||||
} // NOLINT(readability/fn_size)
|
||||
|
||||
// ::= '{' (single_sharding | tuple_sharding) '}'
|
||||
//
|
||||
|
@ -560,6 +560,20 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
|
||||
ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// infeed/outfeed
|
||||
{
|
||||
"InfeedOutfeed",
|
||||
R"(HloModule outfeed_module:
|
||||
|
||||
ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
|
||||
%infeed = (u32[3]{0}, pred[]) infeed()
|
||||
%outfeed = () outfeed((u32[3]{0}, pred[]) %infeed)
|
||||
ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed()
|
||||
%outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1)
|
||||
}
|
||||
|
||||
)"
|
||||
}
|
||||
});
|
||||
@ -866,7 +880,7 @@ TEST_F(HloParserTest, CommaBetweenSubAttributes) {
|
||||
const string original = R"(HloModule test_comma_module:
|
||||
|
||||
ENTRY %test_comma.v4 () -> f32[] {
|
||||
ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="const"}
|
||||
ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
|
||||
}
|
||||
|
||||
)";
|
||||
|
@ -17,3 +17,5 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0):
|
||||
protoc="@protobuf_archive//:protoc",
|
||||
testonly=testonly,
|
||||
visibility=visibility,)
|
||||
|
||||
ORC_JIT_MEMORY_MAPPER_TARGETS = []
|
||||
|
@ -19,6 +19,7 @@ py_library(
|
||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
@ -32,7 +33,6 @@ py_library(
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
@ -99,6 +99,25 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "layers_dense_variational_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/layers_dense_variational_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "monte_carlo_test",
|
||||
size = "small",
|
||||
@ -160,6 +179,27 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "sgld_optimizer_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/sgld_optimizer_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_seed",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -25,16 +25,28 @@ from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence
|
||||
from tensorflow.contrib.bayesflow.python.ops import custom_grad
|
||||
from tensorflow.contrib.bayesflow.python.ops import halton_sequence
|
||||
from tensorflow.contrib.bayesflow.python.ops import hmc
|
||||
from tensorflow.contrib.bayesflow.python.ops import layers
|
||||
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
|
||||
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
|
||||
from tensorflow.contrib.bayesflow.python.ops import optimizers
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
|
||||
_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy',
|
||||
'metropolis_hastings', 'monte_carlo', 'halton_sequence',
|
||||
'hmc', 'special_math', 'stochastic_variables',
|
||||
'variational_inference']
|
||||
_allowed_symbols = [
|
||||
'csiszar_divergence',
|
||||
'custom_grad',
|
||||
'entropy',
|
||||
'halton_sequence',
|
||||
'hmc',
|
||||
'layers',
|
||||
'metropolis_hastings',
|
||||
'monte_carlo',
|
||||
'optimizers',
|
||||
'special_math',
|
||||
'stochastic_variables',
|
||||
'variational_inference',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@ -0,0 +1,304 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for dense Bayesian layers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops.distributions import normal as normal_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class Counter(object):
|
||||
"""Helper class to manage incrementing a counting `int`."""
|
||||
|
||||
def __init__(self):
|
||||
self._value = -1
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
def __call__(self):
|
||||
self._value += 1
|
||||
return self._value
|
||||
|
||||
|
||||
class MockDistribution(normal_lib.Normal):
|
||||
"""Monitors DenseVariational calls to the underlying distribution."""
|
||||
|
||||
def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
|
||||
self.result_sample = result_sample
|
||||
self.result_log_prob = result_log_prob
|
||||
self.result_loc = loc
|
||||
self.result_scale = scale
|
||||
self.called_log_prob = Counter()
|
||||
self.called_sample = Counter()
|
||||
self.called_loc = Counter()
|
||||
self.called_scale = Counter()
|
||||
|
||||
def log_prob(self, *args, **kwargs):
|
||||
self.called_log_prob()
|
||||
return self.result_log_prob
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
self.called_sample()
|
||||
return self.result_sample
|
||||
|
||||
@property
|
||||
def loc(self):
|
||||
self.called_loc()
|
||||
return self.result_loc
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
self.called_scale()
|
||||
return self.result_scale
|
||||
|
||||
|
||||
class MockKLDivergence(object):
|
||||
"""Monitors DenseVariational calls to the divergence implementation."""
|
||||
|
||||
def __init__(self, result):
|
||||
self.result = result
|
||||
self.args = []
|
||||
self.called = Counter()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.called()
|
||||
self.args.append(args)
|
||||
return self.result
|
||||
|
||||
|
||||
class DenseVariationalLocalReparametrization(test.TestCase):
  """Tests for `DenseVariational` with default and mocked distributions."""

  def _assert_regularization_losses(self, layer, expected_count):
    """Checks the loss collection size and that `layer.losses` mirrors it."""
    collected = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    self.assertEqual(len(collected), expected_count)
    self.assertListEqual(layer.losses, collected)

  def testKLPenaltyKernel(self):
    with self.test_session():
      layer = prob_layers_lib.DenseVariational(units=2)
      features = random_ops.random_uniform([2, 3], seed=1)

      # Nothing is registered until the layer has been applied.
      self._assert_regularization_losses(layer, 0)

      _ = layer(features)

      # Applying the layer registers a single loss.
      self._assert_regularization_losses(layer, 1)

  def testKLPenaltyBoth(self):
    def _make_normal(dtype, *args):  # pylint: disable=unused-argument
      zero = dtype.as_numpy_dtype(0.)
      one = dtype.as_numpy_dtype(1.)
      return normal_lib.Normal(loc=zero, scale=one)

    with self.test_session():
      layer = prob_layers_lib.DenseVariational(
          units=2,
          bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(),
          bias_prior_fn=_make_normal)
      features = random_ops.random_uniform([2, 3], seed=1)

      # Nothing is registered until the layer has been applied.
      self._assert_regularization_losses(layer, 0)

      _ = layer(features)

      # With a bias prior supplied as well, two losses are registered.
      self._assert_regularization_losses(layer, 2)

  def testVariationalNonLocal(self):
    batch_size, in_size, out_size = 2, 3, 4
    with self.test_session() as sess:
      next_seed = Counter()

      def _uniform(shape):
        # Every draw takes its seed from the shared counter.
        return random_ops.random_uniform(shape, seed=next_seed())

      inputs = _uniform([batch_size, in_size])

      kernel_shape = [in_size, out_size]
      kernel_posterior = MockDistribution(
          result_log_prob=_uniform(kernel_shape),
          result_sample=_uniform(kernel_shape))
      kernel_prior = MockDistribution(
          result_log_prob=_uniform(kernel_shape),
          result_sample=_uniform(kernel_shape))
      kernel_divergence = MockKLDivergence(result=_uniform(kernel_shape))

      bias_shape = [out_size]
      bias_posterior = MockDistribution(
          result_log_prob=_uniform(bias_shape),
          result_sample=_uniform(bias_shape))
      bias_prior = MockDistribution(
          result_log_prob=_uniform(bias_shape),
          result_sample=_uniform(bias_shape))
      bias_divergence = MockKLDivergence(result=_uniform(bias_shape))

      # Reference computation: inputs times the mocked kernel sample, plus
      # the mocked bias sample.
      projected = math_ops.matmul(inputs, kernel_posterior.result_sample)
      expected_outputs = projected + bias_posterior.result_sample

      layer = prob_layers_lib.DenseVariational(
          units=2,
          kernel_use_local_reparameterization=False,
          kernel_posterior_fn=lambda *args: kernel_posterior,
          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
          kernel_prior_fn=lambda *args: kernel_prior,
          kernel_divergence_fn=kernel_divergence,
          bias_posterior_fn=lambda *args: bias_posterior,
          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
          bias_prior_fn=lambda *args: bias_prior,
          bias_divergence_fn=bias_divergence)

      outputs = layer(inputs)

      kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)

      # Fetch every expected/actual pair in a single run.
      fetches = [
          expected_outputs, outputs,
          kernel_posterior.result_sample, layer.kernel.posterior_tensor,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, layer.bias.posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ]
      (expected_outputs_, actual_outputs_,
       expected_kernel_, actual_kernel_,
       expected_kernel_divergence_, actual_kernel_divergence_,
       expected_bias_, actual_bias_,
       expected_bias_divergence_, actual_bias_divergence_) = sess.run(fetches)

      comparisons = [
          (expected_kernel_, actual_kernel_),
          (expected_bias_, actual_bias_),
          (expected_outputs_, actual_outputs_),
          (expected_kernel_divergence_, actual_kernel_divergence_),
          (expected_bias_divergence_, actual_bias_divergence_),
      ]
      for expected_, actual_ in comparisons:
        self.assertAllClose(expected_, actual_, rtol=1e-6, atol=0.)

      # Each divergence mock must have been called exactly once, with the
      # posterior, the prior and the posterior sample.
      self.assertAllEqual(
          [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior, bias_prior, bias_posterior.result_sample]],
          bias_divergence.args)

  def testVariationalLocal(self):
    batch_size, in_size, out_size = 2, 3, 4
    with self.test_session() as sess:
      next_seed = Counter()

      def _uniform(shape):
        # Every draw takes its seed from the shared counter.
        return random_ops.random_uniform(shape, seed=next_seed())

      inputs = _uniform([batch_size, in_size])

      kernel_shape = [in_size, out_size]
      kernel_posterior = MockDistribution(
          loc=_uniform(kernel_shape),
          scale=_uniform(kernel_shape),
          result_log_prob=_uniform(kernel_shape),
          result_sample=_uniform(kernel_shape))
      kernel_prior = MockDistribution(
          result_log_prob=_uniform(kernel_shape),
          result_sample=_uniform(kernel_shape))
      kernel_divergence = MockKLDivergence(result=_uniform(kernel_shape))

      bias_shape = [out_size]
      bias_posterior = MockDistribution(
          result_log_prob=_uniform(bias_shape),
          result_sample=_uniform(bias_shape))
      bias_prior = MockDistribution(
          result_log_prob=_uniform(bias_shape),
          result_sample=_uniform(bias_shape))
      bias_divergence = MockKLDivergence(result=_uniform(bias_shape))

      # Reference computation: a Normal over `inputs @ kernel` built from the
      # mocked posterior's loc and scale, sampled with the same seed that the
      # layer's kernel tensor fn uses.
      affine_loc = math_ops.matmul(inputs, kernel_posterior.result_loc)
      affine_scale = math_ops.matmul(
          inputs**2., kernel_posterior.result_scale**2)**0.5
      expected_affine = normal_lib.Normal(loc=affine_loc, scale=affine_scale)
      expected_affine_sample = expected_affine.sample(seed=42)
      expected_outputs = expected_affine_sample + bias_posterior.result_sample

      layer = prob_layers_lib.DenseVariational(
          units=2,
          kernel_use_local_reparameterization=True,
          kernel_posterior_fn=lambda *args: kernel_posterior,
          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
          kernel_prior_fn=lambda *args: kernel_prior,
          kernel_divergence_fn=kernel_divergence,
          bias_posterior_fn=lambda *args: bias_posterior,
          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
          bias_prior_fn=lambda *args: bias_prior,
          bias_divergence_fn=bias_divergence)

      outputs = layer(inputs)

      kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)

      # Fetch every expected/actual pair in a single run.
      fetches = [
          expected_outputs, outputs,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, layer.bias.posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ]
      (expected_outputs_, actual_outputs_,
       expected_kernel_divergence_, actual_kernel_divergence_,
       expected_bias_, actual_bias_,
       expected_bias_divergence_, actual_bias_divergence_) = sess.run(fetches)

      comparisons = [
          (expected_bias_, actual_bias_),
          (expected_outputs_, actual_outputs_),
          (expected_kernel_divergence_, actual_kernel_divergence_),
          (expected_bias_divergence_, actual_bias_divergence_),
      ]
      for expected_, actual_ in comparisons:
        self.assertAllClose(expected_, actual_, rtol=1e-6, atol=0.)

      # In this mode the kernel divergence receives `None` in place of a
      # kernel sample.
      self.assertAllEqual(
          [[kernel_posterior, kernel_prior, None]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior, bias_prior, bias_posterior.result_sample]],
          bias_divergence.args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run the tests in this file when it is executed as a script.
  test.main()
|
@ -0,0 +1,209 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functional test for SGLDOptimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import math
|
||||
from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SGLDOptimizerTest(test.TestCase):
  """Single-step checks of `SGLDOptimizer` on dense and sparse gradients."""

  # Floating-point types every test is repeated for.
  _FLOAT_DTYPES = (dtypes.half, dtypes.float32, dtypes.float64)

  @staticmethod
  def _scaled_grad(grad, decay_rate):
    """Returns the per-element step these tests expect for gradient `grad`.

    The learning rate is not applied here; callers multiply by it.
    """
    return 0.5 * grad / math.sqrt(
        decay_rate + (1 - decay_rate) * grad**2 + 1e-8)

  def testBasic(self):
    for dtype in self._FLOAT_DTYPES:
      with self.test_session():
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        decay_rate = 0.53
        optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate)
        update = optimizer.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # The variables start out at their initial values.
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        # Apply one optimizer step.
        update.run()
        # Each entry moved by learning rate (3.0) times the expected step.
        step = self._scaled_grad(0.1, decay_rate)
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * step, 2.1 - 3.0 * step], var0.eval())
        step = self._scaled_grad(0.01, decay_rate)
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * step, 4.0 - 3.0 * step], var1.eval())

  def testBasicMultiInstance(self):
    for dtype in self._FLOAT_DTYPES:
      with self.test_session():
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        vara = variables.Variable([1.1, 2.1], dtype=dtype)
        varb = variables.Variable([3.0, 4.0], dtype=dtype)
        gradsa = constant_op.constant([0.1, 0.1], dtype=dtype)
        gradsb = constant_op.constant([0.01, 0.01], dtype=dtype)
        decay_rate = 0.5
        # Two independently constructed optimizers with identical settings.
        sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate)
        sgd_op = sgd_optimizer.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        sgd_optimizer2 = SGLDOptimizer(
            3.0, preconditioner_decay_rate=decay_rate)
        sgd_op2 = sgd_optimizer2.apply_gradients(
            zip([gradsa, gradsb], [vara, varb]))
        variables.global_variables_initializer().run()
        # The variables start out at their initial values.
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval())

        # Apply one step with each optimizer.
        sgd_op.run()
        sgd_op2.run()
        # Both variable sets must have moved identically.
        step = self._scaled_grad(0.1, decay_rate)
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * step, 2.1 - 3.0 * step], var0.eval())
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * step, 2.1 - 3.0 * step], vara.eval())

        step = self._scaled_grad(0.01, decay_rate)
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * step, 4.0 - 3.0 * step], var1.eval())
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * step, 4.0 - 3.0 * step], varb.eval())
        # The two instances must not share a variable scope.
        self.assertNotEqual(sgd_optimizer.variable_scope,
                            sgd_optimizer2.variable_scope)
        self.assertNotEqual(sgd_optimizer.variable_scope.name,
                            sgd_optimizer2.variable_scope.name)

  def testTensorLearningRate(self):
    for dtype in self._FLOAT_DTYPES:
      with self.test_session():
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        # Learning rate and decay rate are supplied as tensors here.
        lrate = constant_op.constant(3.0)
        decay_rate = 0.5
        decay_tensor = constant_op.constant(decay_rate)
        optimizer = SGLDOptimizer(
            lrate, preconditioner_decay_rate=decay_tensor)
        update = optimizer.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # The variables start out at their initial values.
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        # Apply one optimizer step.
        update.run()
        # Each entry moved by learning rate (3.0) times the expected step.
        step = self._scaled_grad(0.1, decay_rate)
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * step, 2.1 - 3.0 * step], var0.eval())
        step = self._scaled_grad(0.01, decay_rate)
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * step, 4.0 - 3.0 * step], var1.eval())

  def testGradWrtRef(self):
    for dtype in self._FLOAT_DTYPES:
      with self.test_session():
        opt = SGLDOptimizer(3.0)
        values = [1.0, 3.0]
        vars_ = [variables.Variable([v], dtype=dtype) for v in values]
        # The loss is the plain sum, so every gradient should equal 1.
        loss = vars_[0] + vars_[1]
        grads_and_vars = opt.compute_gradients(loss, vars_)
        variables.global_variables_initializer().run()
        for grad, _ in grads_and_vars:
          self.assertAllCloseAccordingToType([1.0], grad.eval())

  def testWithGlobalStep(self):
    for dtype in self._FLOAT_DTYPES:
      with self.test_session():
        global_step = variables.Variable(0, trainable=False)
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        decay_rate = 0.1
        optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate)
        update = optimizer.apply_gradients(
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
        variables.global_variables_initializer().run()
        # The variables start out at their initial values.
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        # Apply one optimizer step.
        update.run()

        # Check both the updated variables and the incremented global step.
        step = self._scaled_grad(0.1, decay_rate)
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * step, 2.1 - 3.0 * step], var0.eval())
        step = self._scaled_grad(0.01, decay_rate)
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * step, 4.0 - 3.0 * step], var1.eval())
        self.assertAllCloseAccordingToType(1, global_step.eval())

  def testSparseBasic(self):
    for dtype in self._FLOAT_DTYPES:
      with self.test_session():
        var0 = variables.Variable([[1.1], [2.1]], dtype=dtype)
        var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
        # grads0 carries a value for row 0 only, grads1 for row 1 only.
        grads0 = ops.IndexedSlices(
            constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
            constant_op.constant([0]), constant_op.constant([2, 1]))
        grads1 = ops.IndexedSlices(
            constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
            constant_op.constant([1]), constant_op.constant([2, 1]))
        decay_rate = 0.9
        optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate)
        update = optimizer.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # The variables start out at their initial values.
        self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval())
        self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
        # Apply one optimizer step.
        update.run()
        # Only the rows named by the slices are expected to move.
        step = self._scaled_grad(0.1, decay_rate)
        self.assertAllCloseAccordingToType([[1.1 - 3.0 * step], [2.1]],
                                           var0.eval())
        step = self._scaled_grad(0.01, decay_rate)
        self.assertAllCloseAccordingToType(
            [[3.0 - 3.0 * 0], [4.0 - 3.0 * step]], var1.eval())
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run the tests in this file when it is executed as a script.
  test.main()
|
37
tensorflow/contrib/bayesflow/python/ops/layers.py
Normal file
37
tensorflow/contrib/bayesflow/python/ops/layers.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Probabilistic neural layers.
|
||||
|
||||
See ${python/contrib.bayesflow.layers}.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
# Names this module is meant to expose; the list is handed to
# `remove_undocumented` together with this module's name.
_allowed_symbols = [
    'DenseVariational',
    'dense_variational',
    'default_loc_scale_fn',
    'default_mean_field_normal_fn',
]

remove_undocumented(__name__, _allowed_symbols)
|
@ -0,0 +1,797 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Dense Bayesian layer using KL-divergence based variational inference.
|
||||
|
||||
@@DenseVariational
|
||||
@@dense_variational
|
||||
|
||||
@@default_loc_scale_fn
|
||||
@@default_mean_field_normal_fn
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import base as layers_lib
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import standard_ops
|
||||
from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
|
||||
from tensorflow.python.ops.distributions import normal as normal_lib
|
||||
|
||||
|
||||
# Public API of this module (what a wildcard import exposes).
__all__ = [
    "DenseVariational",
    "dense_variational",
    "default_loc_scale_fn",
    "default_mean_field_normal_fn",
]
|
||||
|
||||
|
||||
def default_loc_scale_fn(
    is_singular=False,
    loc_initializer=init_ops.random_normal_initializer(stddev=0.1),
    untransformed_scale_initializer=init_ops.random_normal_initializer(
        mean=-3., stddev=0.1),
    loc_regularizer=None,
    untransformed_scale_regularizer=None,
    loc_constraint=None,
    untransformed_scale_constraint=None):
  """Builds a closure that creates `loc` and `scale` parameters.

  The returned closure obtains its variables through a caller-supplied
  `tf.get_variable`-like function and takes the following arguments:

    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Python `str` prefix for variable names; the closure requests
      `name + "_loc"` and `name + "_untransformed_scale"`.
    trainable: Python `bool` forwarded to `add_variable_fn` for every
      variable; it indicates whether created `tf.Variable`s should be added
      to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
      access existing) `tf.Variable`s.

  `scale` is computed as the machine epsilon of `dtype` plus
  `softplus(untransformed_scale)`.

  Args:
    is_singular: Python `bool`; when true the closure creates only `loc` and
      returns `None` for `scale`. Default: `False`.
    loc_initializer: Initializer function for the `loc` parameters.
      The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`.
    untransformed_scale_initializer: Initializer function for the
      untransformed `scale` parameters. Default value:
      `tf.random_normal_initializer(mean=-3., stddev=0.1)`. This implies the
      softplus transformed result has mean approximately `0.05` and std.
      deviation approximately `0.005`.
    loc_regularizer: Regularizer function for the `loc` parameters.
      The default (`None`) is to use the `tf.get_variable` default.
    untransformed_scale_regularizer: Regularizer function for the
      untransformed `scale` parameters. The default (`None`) is to use the
      `tf.get_variable` default.
    loc_constraint: An optional projection function to be applied to the
      loc after being updated by an `Optimizer`. The function must take as
      input the unprojected variable and must return the projected variable
      (which must have the same shape). Constraints are not safe to use when
      doing asynchronous distributed training. The default (`None`) is to use
      the `tf.get_variable` default.
    untransformed_scale_constraint: An optional projection function to be
      applied to the untransformed `scale` parameters after being updated by
      an `Optimizer` (e.g. used to implement norm constraints or value
      constraints). Same requirements and caveats as `loc_constraint`. The
      default (`None`) is to use the `tf.get_variable` default.

  Returns:
    default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale`
    parameters from args: `dtype, shape, name, trainable, add_variable_fn`.
  """
  def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Creates `loc`, `scale` parameters."""
    # Settings common to both variables.
    shared_kwargs = dict(shape=shape, dtype=dtype, trainable=trainable)
    loc = add_variable_fn(
        name=name + "_loc",
        initializer=loc_initializer,
        regularizer=loc_regularizer,
        constraint=loc_constraint,
        **shared_kwargs)
    if not is_singular:
      raw_scale = add_variable_fn(
          name=name + "_untransformed_scale",
          initializer=untransformed_scale_initializer,
          regularizer=untransformed_scale_regularizer,
          constraint=untransformed_scale_constraint,
          **shared_kwargs)
      # Softplus is non-negative; adding eps keeps the scale above zero.
      eps = np.finfo(dtype.as_numpy_dtype).eps
      return loc, eps + nn_ops.softplus(raw_scale)
    # Singular case: no scale variable is created at all.
    return loc, None
  return _fn
|
||||
|
||||
|
||||
def default_mean_field_normal_fn(
    is_singular=False,
    loc_initializer=None,
    untransformed_scale_initializer=None,
    loc_regularizer=None,
    untransformed_scale_regularizer=None,
    loc_constraint=None,
    untransformed_scale_constraint=None):
  """Builds a closure that creates a Normal with trainable parameters.

  The returned closure produces a `tf.distributions.Normal` parameterized by
  a `loc` and `scale`, both obtained from the closure returned by
  `default_loc_scale_fn`. It takes the following arguments, in this order:

    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Python `str` name prepended to any created (or existing)
      `tf.Variable`s.
    trainable: Python `bool` indicating all created `tf.Variable`s should be
      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
      access existing) `tf.Variable`s.

  All arguments below are forwarded unchanged to `default_loc_scale_fn`; in
  particular the `None` defaults here replace that function's own defaults.

  Args:
    is_singular: Python `bool` if `True`, forces the special case limit of
      `scale->0`, i.e., a `Deterministic` distribution.
    loc_initializer: Initializer function for the `loc` parameters.
      If `None` (default), values are initialized using the default
      initializer used by `tf.get_variable`.
    untransformed_scale_initializer: Initializer function for the
      untransformed `scale` parameters. If `None` (default), values are
      initialized using the default initializer used by `tf.get_variable`.
    loc_regularizer: Regularizer function for the `loc` parameters.
    untransformed_scale_regularizer: Regularizer function for the
      untransformed `scale` parameters.
    loc_constraint: An optional projection function to be applied to the
      loc after being updated by an `Optimizer`. The function must take as
      input the unprojected variable and must return the projected variable
      (which must have the same shape). Constraints are not safe to use when
      doing asynchronous distributed training.
    untransformed_scale_constraint: An optional projection function to be
      applied to the untransformed `scale` parameters after being updated by
      an `Optimizer` (e.g. used to implement norm constraints or value
      constraints). Same requirements and caveats as `loc_constraint`.

  Returns:
    make_normal_fn: Python `callable` which creates a
    `tf.distributions.Normal` (or a `Deterministic` when no scale is
    produced) from args: `dtype, shape, name, trainable, add_variable_fn`.
  """
  make_loc_scale = default_loc_scale_fn(
      is_singular=is_singular,
      loc_initializer=loc_initializer,
      untransformed_scale_initializer=untransformed_scale_initializer,
      loc_regularizer=loc_regularizer,
      untransformed_scale_regularizer=untransformed_scale_regularizer,
      loc_constraint=loc_constraint,
      untransformed_scale_constraint=untransformed_scale_constraint)

  def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Creates a batch of `Deterministic` or `Normal` distributions."""
    loc, scale = make_loc_scale(dtype, shape, name, trainable, add_variable_fn)
    if scale is not None:
      return normal_lib.Normal(loc=loc, scale=scale)
    # A missing scale marks the singular case: a point mass at `loc`.
    return deterministic_lib.Deterministic(loc=loc)
  return _fn
|
||||
|
||||
|
||||
class DenseVariational(layers_lib.Layer):
|
||||
"""Densely-connected variational class.
|
||||
|
||||
This layer implements the Bayesian variational inference analogue to:
|
||||
`outputs = activation(matmul(inputs, kernel) + bias)`
|
||||
by assuming the `kernel` and/or the `bias` are random variables.
|
||||
|
||||
The layer implements a stochastic dense calculation by making a Monte Carlo
|
||||
approximation of a [variational Bayesian method based on KL divergence](
|
||||
https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
|
||||
|
||||
```none
|
||||
-log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
|
||||
= -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
|
||||
<= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's
|
||||
= E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
|
||||
~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
|
||||
+ KL[q(W|x), p(W)]
|
||||
```
|
||||
|
||||
where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
|
||||
is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`,
|
||||
and `~=` denotes an approximation which becomes exact as `m->inf`. The above
|
||||
bound is sometimes referred to as the negative Evidence Lower BOund or
|
||||
negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
|
||||
layer is appropriate to use when the final loss is a negative log-likelihood.
|
||||
|
||||
The Monte-Carlo sum portion is used for the feed-forward calculation of the
|
||||
DNN. The KL divergence portion can be added to the final loss via:
|
||||
`loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
|
||||
|
||||
The arguments permit separate specification of the surrogate posterior
|
||||
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
|
||||
random variables (which together comprise `W`).
|
||||
|
||||
Args:
|
||||
units: Integer or Long, dimensionality of the output space.
|
||||
activation: Activation function (`callable`). Set it to None to maintain a
|
||||
linear activation.
|
||||
activity_regularizer: Regularizer function for the output.
|
||||
trainable: Boolean, if `True` also add variables to the graph collection
|
||||
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
|
||||
kernel_use_local_reparameterization: Python `bool` indicating whether
|
||||
`kernel` calculation should employ the Local Reparameterization Trick.
|
||||
When `True`, `kernel_posterior_fn` must create an instance of
|
||||
`tf.distributions.Normal`.
|
||||
kernel_posterior_fn: Python `callable` which creates
|
||||
`tf.distributions.Distribution` instance representing the surrogate
|
||||
posterior of the `kernel` parameter. Default value:
|
||||
`default_mean_field_normal_fn()`.
|
||||
kernel_posterior_tensor_fn: Python `callable` which takes a
|
||||
`tf.distributions.Distribution` instance and returns a representative
|
||||
value. Default value: `lambda d: d.sample()`.
|
||||
kernel_prior_fn: Python `callable` which creates `tf.distributions`
|
||||
instance. See `default_mean_field_normal_fn` docstring for required
|
||||
parameter signature.
|
||||
Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
|
||||
kernel_divergence_fn: Python `callable` which takes the surrogate posterior
|
||||
distribution, prior distribution and random variate sample(s) from the
|
||||
surrogate posterior and computes or approximates the KL divergence. The
|
||||
distributions are `tf.distributions.Distribution`-like instances and the
|
||||
sample is a `Tensor`.
|
||||
bias_posterior_fn: Python `callable` which creates
|
||||
`tf.distributions.Distribution` instance representing the surrogate
|
||||
posterior of the `bias` parameter. Default value:
|
||||
`default_mean_field_normal_fn(is_singular=True)` (which creates an
|
||||
instance of `tf.distributions.Deterministic`).
|
||||
bias_posterior_tensor_fn: Python `callable` which takes a
|
||||
`tf.distributions.Distribution` instance and returns a representative
|
||||
value. Default value: `lambda d: d.sample()`.
|
||||
bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
|
||||
See `default_mean_field_normal_fn` docstring for required parameter
|
||||
signature. Default value: `None` (no prior, no variational inference)
|
||||
bias_divergence_fn: Python `callable` which takes the surrogate posterior
|
||||
distribution, prior distribution and random variate sample(s) from the
|
||||
surrogate posterior and computes or approximates the KL divergence. The
|
||||
distributions are `tf.distributions.Distribution`-like instances and the
|
||||
sample is a `Tensor`.
|
||||
name: Python `str`, the name of the layer. Layers with the same name will
|
||||
share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
|
||||
such cases.
|
||||
reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
|
||||
layer by the same name.
|
||||
|
||||
Properties:
|
||||
units: Python integer, dimensionality of the output space.
|
||||
activation: Activation function (`callable`).
|
||||
activity_regularizer: Regularizer function for the output.
|
||||
kernel_use_local_reparameterization: Python `bool` indicating whether
|
||||
`kernel` calculation should employ the Local Reparameterization Trick.
|
||||
kernel: `VariationalKernelParamater` instance containing all `kernel`
|
||||
related properties and `callable`s.
|
||||
bias: `VariationalParameter` instance containing all `kernel`
|
||||
related properties and `callable`s.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_use_local_reparameterization=True,
    kernel_posterior_fn=default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
        loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    name=None,
    **kwargs):
  """Stores the layer configuration; no distributions are created here.

  The four `kernel_*` callables are bundled into a
  `VariationalKernelParameter` and the four `bias_*` callables into a
  `VariationalParameter`. The distributions themselves are created later, in
  `build`.

  Args:
    units: Dimensionality of the output space.
    activation: Optional activation `callable` applied to the output.
    activity_regularizer: Forwarded to the base layer constructor.
    trainable: Forwarded to the base layer constructor.
    kernel_use_local_reparameterization: Python `bool`; stored and later
      consulted when the kernel is applied.
    kernel_posterior_fn: `callable` creating the kernel surrogate posterior.
    kernel_posterior_tensor_fn: `callable` mapping a distribution to a
      representative `Tensor`.
    kernel_prior_fn: `callable` creating the kernel prior, or `None`.
    kernel_divergence_fn: `callable` computing the kernel divergence.
    bias_posterior_fn: `callable` creating the bias surrogate posterior, or
      `None`.
    bias_posterior_tensor_fn: `callable` mapping a distribution to a
      representative `Tensor`.
    bias_prior_fn: `callable` creating the bias prior, or `None`.
    bias_divergence_fn: `callable` computing the bias divergence.
    name: Forwarded to the base layer constructor.
    **kwargs: Additional keyword arguments forwarded to the base layer.
  """
  super(DenseVariational, self).__init__(
      name=name,
      trainable=trainable,
      activity_regularizer=activity_regularizer,
      **kwargs)

  # Plain configuration values.
  self._units = units
  self._activation = activation
  self._kernel_use_local_reparameterization = (
      kernel_use_local_reparameterization)
  self._input_spec = layers_lib.InputSpec(min_ndim=2)

  # Struct-like holders for everything related to each random variable.
  self._kernel = VariationalKernelParameter(
      posterior_fn=kernel_posterior_fn,
      posterior_tensor_fn=kernel_posterior_tensor_fn,
      prior_fn=kernel_prior_fn,
      divergence_fn=kernel_divergence_fn)
  self._bias = VariationalParameter(
      posterior_fn=bias_posterior_fn,
      posterior_tensor_fn=bias_posterior_tensor_fn,
      prior_fn=bias_prior_fn,
      divergence_fn=bias_divergence_fn)
|
||||
|
||||
@property
def units(self):
  """Output dimensionality given to the constructor."""
  return self._units
|
||||
|
||||
@property
def activation(self):
  """Activation `callable` given to the constructor (may be `None`)."""
  return self._activation
|
||||
|
||||
@property
def input_spec(self):
  """Input specification currently stored on the layer."""
  return self._input_spec

@input_spec.setter
def input_spec(self, value):
  """Replaces the stored input specification with `value`."""
  self._input_spec = value
|
||||
|
||||
@property
def kernel_use_local_reparameterization(self):
  """The `kernel_use_local_reparameterization` flag given to the constructor."""
  return self._kernel_use_local_reparameterization
|
||||
|
||||
@property
def kernel(self):
  """`VariationalKernelParameter` holding the kernel's variational state."""
  return self._kernel
|
||||
|
||||
@property
def bias(self):
  """`VariationalParameter` holding the bias's variational state."""
  return self._bias
|
||||
|
||||
def build(self, input_shape):
  """Creates the posterior and prior distributions for `kernel` and `bias`.

  Each `*_fn` stored on `self.kernel` / `self.bias` is invoked as
  `fn(dtype, shape, name, trainable, add_variable_fn)`. A `None` callable
  results in the corresponding distribution being set to `None`, which
  `_apply_divergence` later treats as "no divergence for this parameter".

  Args:
    input_shape: Shape of the layer input (anything accepted by
      `tensor_shape.TensorShape`). Must have rank at least 2 and a statically
      known last dimension.

  Raises:
    ValueError: if the last dimension of `input_shape` is not defined.
  """
  input_shape = tensor_shape.TensorShape(input_shape)
  in_size = input_shape.with_rank_at_least(2)[-1].value
  if in_size is None:
    raise ValueError("The last dimension of the inputs to `Dense` "
                     "should be defined. Found `None`.")
  self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size})
  dtype = dtypes.as_dtype(self.dtype)

  # Must have a posterior kernel.
  self.kernel.posterior = self.kernel.posterior_fn(
      dtype, [in_size, self.units], "kernel_posterior",
      self.trainable, self.add_variable)

  if self.kernel.prior_fn is None:
    # Assign through the parameter struct (not a stray `self.kernel_prior`
    # attribute). Otherwise `kernel.prior` stays a `NotSet` sentinel and the
    # `param.prior is None` guard in `_apply_divergence` would be bypassed.
    self.kernel.prior = None
  else:
    self.kernel.prior = self.kernel.prior_fn(
        dtype, [in_size, self.units], "kernel_prior",
        self.trainable, self.add_variable)
  # The divergence is added lazily, on the first `call`.
  self._built_kernel_divergence = False

  if self.bias.posterior_fn is None:
    self.bias.posterior = None
  else:
    self.bias.posterior = self.bias.posterior_fn(
        dtype, [self.units], "bias_posterior",
        self.trainable, self.add_variable)

  if self.bias.prior_fn is None:
    self.bias.prior = None
  else:
    self.bias.prior = self.bias.prior_fn(
        dtype, [self.units], "bias_prior",
        self.trainable, self.add_variable)
  self._built_bias_divergence = False

  self.built = True
|
||||
|
||||
def call(self, inputs):
  """Runs the stochastic dense computation on `inputs`.

  Applies the variational kernel, then the variational bias, then the optional
  activation. The first time this runs after `build`, the kernel and bias
  divergences are also computed via `_apply_divergence`.

  Args:
    inputs: Value convertible to a `Tensor` of dtype `self.dtype`.

  Returns:
    The layer output.
  """
  tensor_in = ops.convert_to_tensor(inputs, dtype=self.dtype)

  result = self._apply_variational_bias(
      self._apply_variational_kernel(tensor_in))
  if self.activation is not None:
    result = self.activation(result)  # pylint: disable=not-callable

  # Each divergence term is built only once per `build`.
  if not self._built_kernel_divergence:
    self._apply_divergence(self.kernel, name="divergence_kernel")
    self._built_kernel_divergence = True
  if not self._built_bias_divergence:
    self._apply_divergence(self.bias, name="divergence_bias")
    self._built_bias_divergence = True
  return result
|
||||
|
||||
def _apply_variational_kernel(self, inputs):
|
||||
if not self.kernel_use_local_reparameterization:
|
||||
self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn(
|
||||
self.kernel.posterior)
|
||||
self.kernel.posterior_affine = None
|
||||
self.kernel.posterior_affine_tensor = None
|
||||
return self._matmul(inputs, self.kernel.posterior_tensor)
|
||||
if not isinstance(self.kernel.posterior, normal_lib.Normal):
|
||||
raise TypeError("`kernel_use_local_reparameterization=True` requires "
|
||||
"`kernel_posterior_fn` produce an instance of "
|
||||
"`tf.distributions.Normal` (saw: \"{}\").".format(
|
||||
type(self.kernel.posterior).__name__))
|
||||
self.kernel.posterior_affine = normal_lib.Normal(
|
||||
loc=self._matmul(inputs, self.kernel.posterior.loc),
|
||||
scale=standard_ops.sqrt(self._matmul(
|
||||
standard_ops.square(inputs),
|
||||
standard_ops.square(self.kernel.posterior.scale))))
|
||||
self.kernel.posterior_affine_tensor = (
|
||||
self.kernel.posterior_tensor_fn(self.kernel.posterior_affine))
|
||||
self.kernel.posterior_tensor = None
|
||||
return self.kernel.posterior_affine_tensor
|
||||
|
||||
def _apply_variational_bias(self, inputs):
|
||||
if self.bias.posterior is None:
|
||||
self.bias.posterior_tensor = None
|
||||
return inputs
|
||||
self.bias.posterior_tensor = self.bias.posterior_tensor_fn(
|
||||
self.bias.posterior)
|
||||
return nn.bias_add(inputs, self.bias.posterior_tensor)
|
||||
|
||||
def _apply_divergence(self, param, name):
|
||||
if (param.divergence_fn is None or
|
||||
param.posterior is None or
|
||||
param.prior is None):
|
||||
param.divergence = None
|
||||
return
|
||||
param.divergence = standard_ops.identity(
|
||||
param.divergence_fn(
|
||||
param.posterior, param.prior, param.posterior_tensor),
|
||||
name=name)
|
||||
self.add_loss(param.divergence)
|
||||
|
||||
def _matmul(self, inputs, kernel):
  """Contracts the last axis of `inputs` with the first axis of `kernel`.

  Args:
    inputs: `Tensor` whose static rank (`inputs.shape.ndims`) selects the op.
    kernel: `Tensor` to multiply by.

  Returns:
    `matmul(inputs, kernel)` for rank <= 2 inputs, otherwise the equivalent
    `tensordot` so that leading dimensions broadcast.
  """
  if inputs.shape.ndims > 2:
    # To handle broadcasting, we must use `tensordot`.
    return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]])
  return standard_ops.matmul(inputs, kernel)
|
||||
|
||||
def _compute_output_shape(self, input_shape):
  """Returns the output shape: `input_shape` with its last dim set to `units`.

  Args:
    input_shape: Shape of the input; must have rank at least 2.

  Returns:
    A `TensorShape` equal to `input_shape[:-1]` followed by `self.units`.

  Raises:
    ValueError: if the innermost dimension of `input_shape` is not defined.
  """
  shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2)
  innermost = shape[-1].value
  if innermost is None:
    raise ValueError(
        "The innermost dimension of input_shape must be defined, "
        "but saw: {}".format(shape))
  leading_dims = shape[:-1]
  return leading_dims.concatenate(self.units)
|
||||
|
||||
|
||||
def dense_variational(
    inputs,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_use_local_reparameterization=True,
    kernel_posterior_fn=default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
        loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    name=None,
    reuse=None):
  """Functional interface to the `DenseVariational` layer.

  Builds a `DenseVariational` layer from the given arguments and applies it to
  `inputs`. The layer is the Bayesian variational inference analogue of

    `outputs = activation(matmul(inputs, kernel) + bias)`

  in which the `kernel` and/or the `bias` are treated as random variables.

  The stochastic dense calculation is a Monte Carlo approximation of a
  [variational Bayesian method based on KL divergence](
  https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,

  ```none
  -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
              = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
             <= E_q(W|x)[-log p(y,W|x) + log q(W|x)]       # Jensen's
              = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
             ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
                 + KL[q(W|x), p(W)]
  ```

  Here `W` denotes the (independent) `kernel` and `bias` random variables, `w`
  is a random variate or outcome of `W`, `y` is the label, `x` is the
  evidence, and `~=` denotes an approximation which becomes exact as
  `m->inf`. The bound is sometimes called the negative Evidence Lower BOund,
  or negative [ELBO](https://arxiv.org/abs/1601.00670). In the context of a
  DNN, this layer is appropriate when the final loss is a negative
  log-likelihood.

  The Monte-Carlo sum is what the feed-forward calculation computes. The KL
  divergence portion can be added to the final loss via:
  `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.

  The surrogate posterior (`q(W|x)`), prior (`p(W)`), and divergence can be
  specified separately for the `kernel` and the `bias` (which together
  comprise `W`).

  Args:
    inputs: Tensor input.
    units: Integer or Long, dimensionality of the output space.
    activation: Activation function (`callable`). `None` keeps the output
      linear.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean; if `True`, variables are also added to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    kernel_use_local_reparameterization: Python `bool` indicating whether the
      `kernel` calculation should employ the Local Reparameterization Trick.
      When `True`, `kernel_posterior_fn` must create an instance of
      `tf.distributions.Normal`.
    kernel_posterior_fn: Python `callable` which creates the
      `tf.distributions.Distribution` instance representing the surrogate
      posterior of the `kernel` parameter. Default value:
      `default_mean_field_normal_fn()`.
    kernel_posterior_tensor_fn: Python `callable` which takes a
      `tf.distributions.Distribution` instance and returns a representative
      value. Default value: `lambda d: d.sample()`.
    kernel_prior_fn: Python `callable` which creates a `tf.distributions`
      instance. See the `default_mean_field_normal_fn` docstring for the
      required parameter signature.
      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
    kernel_divergence_fn: Python `callable` which takes the surrogate
      posterior distribution, the prior distribution and random variate
      sample(s) from the surrogate posterior, and computes or approximates the
      KL divergence. The distributions are
      `tf.distributions.Distribution`-like instances and the sample is a
      `Tensor`.
    bias_posterior_fn: Python `callable` which creates the
      `tf.distributions.Distribution` instance representing the surrogate
      posterior of the `bias` parameter. Default value:
      `default_mean_field_normal_fn(is_singular=True)` (which creates an
      instance of `tf.distributions.Deterministic`).
    bias_posterior_tensor_fn: Python `callable` which takes a
      `tf.distributions.Distribution` instance and returns a representative
      value. Default value: `lambda d: d.sample()`.
    bias_prior_fn: Python `callable` which creates a `tf.distributions`
      instance. See the `default_mean_field_normal_fn` docstring for the
      required parameter signature. Default value: `None` (no prior, no
      variational inference).
    bias_divergence_fn: Same contract as `kernel_divergence_fn`, applied to
      the `bias`.
    name: Python `str`, the name of the layer. Layers with the same name will
      share `tf.Variable`s, but to avoid mistakes `reuse=True` is required in
      such cases.
    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
      layer by the same name.

  Returns:
    output: `Tensor` representing the affine-transformed input under a random
      draw from the surrogate posterior distribution.
  """
  layer_kwargs = dict(
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      kernel_use_local_reparameterization=(
          kernel_use_local_reparameterization),
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      name=name,
      # The layer computes in the (non-reference) dtype of its input.
      dtype=inputs.dtype.base_dtype,
      _scope=name,
      _reuse=reuse)
  return DenseVariational(units, **layer_kwargs).apply(inputs)
|
||||
|
||||
|
||||
class NotSet(object):
  """Sentinel type marking a `VariationalParameter` value as not yet set.

  Instances carry no state; "unset" is detected with `isinstance`, so `None`
  remains available as a legitimate assigned value.
  """
|
||||
|
||||
|
||||
class VariationalParameter(object):
  """Struct-like container of variational parameter properties.

  A `VariationalParameter` is initialized with Python `callable`s
  (`posterior_fn`, `posterior_tensor_fn`, `prior_fn`, `divergence_fn`). The
  correspondingly named values (`posterior`, `posterior_tensor`, `prior`,
  `divergence`) start out unset and have "set once" semantics: each may be
  assigned exactly once (to any value, including `None`), after which further
  assignment raises `ValueError`.
  """

  def __init__(
      self,
      posterior_fn,
      posterior_tensor_fn,
      prior_fn,
      divergence_fn):
    """Creates the `VariationalParameter` struct-like object.

    Args:
      posterior_fn: Python `callable` which creates a
        `tf.distributions.Distribution`-like object representing the
        posterior distribution. See `VariationalParameter.posterior_fn` for
        the `callable`'s required parameters.
      posterior_tensor_fn: Python `callable` which computes a `Tensor`
        representing the `posterior`.
      prior_fn: Python `callable` which creates a
        `tf.distributions.Distribution`-like object representing the prior
        distribution. See `VariationalParameter.prior_fn` for the
        `callable`'s required parameters.
      divergence_fn: Python `callable` which computes the KL divergence from
        `posterior` to `prior`. See `VariationalParameter.divergence_fn` for
        the `callable`'s required parameters.
    """
    self._posterior_fn = posterior_fn
    self._posterior_tensor_fn = posterior_tensor_fn
    self._prior_fn = prior_fn
    self._divergence_fn = divergence_fn
    # Every set-once value starts as its own `NotSet` sentinel.
    for slot in ("_posterior", "_posterior_tensor", "_prior", "_divergence"):
      setattr(self, slot, NotSet())
    # Hook that lets subclasses add their own set-once values.
    self._init_helper()

  def _init_helper(self):
    """Subclass hook invoked at the end of `__init__`; a no-op here."""
    pass

  def _set_once(self, slot, value):
    """Stores `value` in attribute `slot` unless it has already been set.

    Args:
      slot: Python `str`, name of the backing attribute.
      value: Value to store.

    Raises:
      ValueError: if the attribute no longer holds a `NotSet` sentinel.
    """
    if not isinstance(getattr(self, slot), NotSet):
      raise ValueError("Cannot override already set attribute.")
    setattr(self, slot, value)

  @property
  def posterior_fn(self):
    """`callable` which creates `tf.distributions.Distribution`-like posterior.

    `DenseVariational.build` invokes the `callable` positionally as
    `fn(dtype, shape, name, trainable, add_variable_fn)`, where:
      dtype: Type of the parameter's event.
      shape: Python `list`-like representing the parameter's event shape.
      name: Python `str` name prepended to any created (or existing)
        `tf.Variable`s.
      trainable: Python `bool` indicating all created `tf.Variable`s should be
        added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
      add_variable_fn: `tf.get_variable`-like `callable` used to create (or
        access existing) `tf.Variable`s.

    Returns:
      posterior_fn: The Python `callable` specified in `__init__`.
    """
    return self._posterior_fn

  @property
  def posterior(self):
    """`tf.distributions.Distribution`-like instance representing posterior."""
    return self._posterior

  @posterior.setter
  def posterior(self, value):
    """One-time setter of the `posterior` distribution."""
    self._set_once("_posterior", value)

  @property
  def posterior_tensor_fn(self):
    """`callable` which creates a `Tensor` representing the `posterior`.

    The `callable` must accept the following parameter:
      posterior: `tf.distributions.Distribution`-like instance.

    Returns:
      posterior_tensor_fn: The Python `callable` specified in `__init__`.
    """
    return self._posterior_tensor_fn

  @property
  def posterior_tensor(self):
    """`Tensor` representing the `posterior` distribution."""
    return self._posterior_tensor

  @posterior_tensor.setter
  def posterior_tensor(self, value):
    """One-time setter of the `posterior_tensor`."""
    self._set_once("_posterior_tensor", value)

  @property
  def prior_fn(self):
    """`callable` which creates `tf.distributions.Distribution`-like prior.

    `DenseVariational.build` invokes the `callable` positionally as
    `fn(dtype, shape, name, trainable, add_variable_fn)`, where:
      dtype: Type of the parameter's event.
      shape: Python `list`-like representing the parameter's event shape.
      name: Python `str` name prepended to any created (or existing)
        `tf.Variable`s.
      trainable: Python `bool` indicating all created `tf.Variable`s should be
        added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
      add_variable_fn: `tf.get_variable`-like `callable` used to create (or
        access existing) `tf.Variable`s.

    Returns:
      prior_fn: The Python `callable` specified in `__init__`.
    """
    return self._prior_fn

  @property
  def prior(self):
    """`tf.distributions.Distribution`-like instance representing the prior."""
    return self._prior

  @prior.setter
  def prior(self, value):
    """One-time setter of the `prior` distribution."""
    self._set_once("_prior", value)

  @property
  def divergence_fn(self):
    """`callable` which computes KL-divergence `Tensor` from posterior to prior.

    The `callable` must accept the following parameters:
      posterior: `tf.distributions.Distribution`-like instance.
      prior: `tf.distributions.Distribution`-like instance.
      posterior_tensor: `Tensor` representing value of posterior.

    Returns:
      divergence_fn: The Python `callable` specified in `__init__`.
    """
    return self._divergence_fn

  @property
  def divergence(self):
    """`Tensor` representing KL-divergence from posterior to prior."""
    return self._divergence

  @divergence.setter
  def divergence(self, value):
    """One-time setter of the `divergence`."""
    self._set_once("_divergence", value)
|
||||
|
||||
|
||||
class VariationalKernelParameter(VariationalParameter):
  """Struct-like container of variational kernel properties.

  Extends `VariationalParameter` with two further values, `posterior_affine`
  and `posterior_affine_tensor`. Like the inherited values they have "set
  once" semantics: once assigned to any value they are immutable.
  """

  def _init_helper(self):
    """Marks the kernel-specific values as not yet set."""
    self._posterior_affine = NotSet()
    self._posterior_affine_tensor = NotSet()

  @property
  def posterior_affine(self):
    """`tf.distributions.Distribution` affine transformed posterior."""
    return self._posterior_affine

  @posterior_affine.setter
  def posterior_affine(self, value):
    """One-time setter of `posterior_affine`."""
    already_assigned = not isinstance(self._posterior_affine, NotSet)
    if already_assigned:
      raise ValueError("Cannot override already set attribute.")
    self._posterior_affine = value

  @property
  def posterior_affine_tensor(self):
    """`Tensor` representing the `posterior_affine` distribution."""
    return self._posterior_affine_tensor

  @posterior_affine_tensor.setter
  def posterior_affine_tensor(self, value):
    """One-time setter of the `posterior_affine_tensor`."""
    already_assigned = not isinstance(self._posterior_affine_tensor, NotSet)
    if already_assigned:
      raise ValueError("Cannot override already set attribute.")
    self._posterior_affine_tensor = value
|
34
tensorflow/contrib/bayesflow/python/ops/optimizers.py
Normal file
34
tensorflow/contrib/bayesflow/python/ops/optimizers.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Probabilistic optimizer modules.
|
||||
|
||||
See ${python/contrib.bayesflow.optimizers}.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
# Symbols passed to `remove_undocumented` as the allowed list for this module
# (presumably everything else is pruned from the namespace; see that helper).
_allowed_symbols = ['SGLDOptimizer']

remove_undocumented(__name__, _allowed_symbols)
|
216
tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py
Normal file
216
tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py
Normal file
@ -0,0 +1,216 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An optimizer module for stochastic gradient Langevin dynamics."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope as varscope_ops
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import training_ops
|
||||
|
||||
|
||||
class SGLDOptimizer(optimizer.Optimizer):
|
||||
"""An optimizer module for stochastic gradient Langevin dynamics.
|
||||
|
||||
This implements the preconditioned Stochastic Gradient Langevin Dynamics
|
||||
optimizer [1]. The optimization variable is regarded as a sample from the
|
||||
posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in
|
||||
each dimension according to RMSProp [2].
|
||||
|
||||
Note: If a prior is included in the loss, it should be scaled by
|
||||
`1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches
|
||||
in the data. I.e., it should be divided by the `num_pseudo_batches` term
|
||||
described below.
|
||||
|
||||
[1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural
|
||||
Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin.
|
||||
ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666
|
||||
[2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
|
||||
|
||||
Args:
|
||||
learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the
|
||||
optimizer. Must be tuned to the specific function being minimized.
|
||||
preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential
|
||||
decay rate of the rescaling of the preconditioner (RMSprop). (This is
|
||||
"alpha" in [1]). Should be smaller than but nearly `1` to approximate
|
||||
sampling from the posterior. (Default: `0.95`)
|
||||
num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of
|
||||
minibatches in the data set. Trades off noise and prior with the SGD
|
||||
likelihood term. Note: Assumes the loss is taken as the mean over a
|
||||
minibatch. Otherwise if the sum was taken, divide this number by the
|
||||
batch size. (Default: `1`)
|
||||
burnin: Scalar `int`-like `Tensor`. The number of iterations to collect
|
||||
gradient statistics to update the preconditioner before starting to draw
|
||||
noisy samples. (Default: `25`)
|
||||
diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of
|
||||
the preconditioner to prevent the preconditioner from degenerating.
|
||||
(Default: `1e-8`)
|
||||
name: Python `str` describing ops managed by this function.
|
||||
(Default: `"SGLDOptimizer"`)
|
||||
variable_scope: Variable scope used for calls to `tf.get_variable`.
|
||||
If `None`, a new variable scope is created using name
|
||||
`ops.get_default_graph().unique_name(name or default_name)`.
|
||||
|
||||
Raises:
|
||||
InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in
|
||||
`(0,1]`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
preconditioner_decay_rate=0.95,
|
||||
num_pseudo_batches=1,
|
||||
burnin=25,
|
||||
diagonal_bias=1e-8,
|
||||
name=None,
|
||||
variable_scope=None):
|
||||
default_name = 'SGLDOptimizer'
|
||||
with ops.name_scope(name, default_name, [
|
||||
learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin,
|
||||
diagonal_bias
|
||||
]):
|
||||
if variable_scope is None:
|
||||
var_scope_name = ops.get_default_graph().unique_name(
|
||||
name or default_name)
|
||||
with varscope_ops.variable_scope(var_scope_name) as scope:
|
||||
self._variable_scope = scope
|
||||
else:
|
||||
self._variable_scope = variable_scope
|
||||
|
||||
self._preconditioner_decay_rate = ops.convert_to_tensor(
|
||||
preconditioner_decay_rate, name='preconditioner_decay_rate')
|
||||
self._num_pseudo_batches = ops.convert_to_tensor(
|
||||
num_pseudo_batches, name='num_pseudo_batches')
|
||||
self._burnin = ops.convert_to_tensor(burnin, name='burnin')
|
||||
self._diagonal_bias = ops.convert_to_tensor(
|
||||
diagonal_bias, name='diagonal_bias')
|
||||
self._learning_rate = ops.convert_to_tensor(
|
||||
learning_rate, name='learning_rate')
|
||||
|
||||
with varscope_ops.variable_scope(self._variable_scope):
|
||||
self._counter = varscope_ops.get_variable(
|
||||
'counter', initializer=0, trainable=False)
|
||||
|
||||
self._preconditioner_decay_rate = control_flow_ops.with_dependencies([
|
||||
check_ops.assert_non_negative(
|
||||
self._preconditioner_decay_rate,
|
||||
message='`preconditioner_decay_rate` must be non-negative'),
|
||||
check_ops.assert_less_equal(
|
||||
self._preconditioner_decay_rate,
|
||||
1.,
|
||||
message='`preconditioner_decay_rate` must be at most 1.'),
|
||||
], self._preconditioner_decay_rate)
|
||||
|
||||
self._num_pseudo_batches = control_flow_ops.with_dependencies([
|
||||
check_ops.assert_greater(
|
||||
self._num_pseudo_batches,
|
||||
0,
|
||||
message='`num_pseudo_batches` must be greater than zero')
|
||||
], self._num_pseudo_batches)
|
||||
|
||||
self._burnin = control_flow_ops.with_dependencies([
|
||||
check_ops.assert_non_negative(
|
||||
self._burnin, message='`burnin` must be non-negative'),
|
||||
check_ops.assert_integer(
|
||||
self._burnin, message='`burnin` must be an integer')
|
||||
], self._burnin)
|
||||
|
||||
self._diagonal_bias = control_flow_ops.with_dependencies([
|
||||
check_ops.assert_non_negative(
|
||||
self._diagonal_bias,
|
||||
message='`diagonal_bias` must be non-negative')
|
||||
], self._diagonal_bias)
|
||||
|
||||
super(SGLDOptimizer, self).__init__(use_locking=False,
|
||||
name=name or default_name)
|
||||
|
||||
def _create_slots(self, var_list):
|
||||
for v in var_list:
|
||||
init_rms = init_ops.ones_initializer(dtype=v.dtype)
|
||||
self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
|
||||
v.dtype, 'rms', self._name)
|
||||
|
||||
def _prepare(self):
|
||||
# We need to put the conversion and check here because a user will likely
|
||||
# want to decay the learning rate dynamically.
|
||||
self._learning_rate_tensor = control_flow_ops.with_dependencies([
|
||||
check_ops.assert_non_negative(
|
||||
self._learning_rate, message='`learning_rate` must be non-negative')
|
||||
], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor'))
|
||||
self._decay_tensor = ops.convert_to_tensor(
|
||||
self._preconditioner_decay_rate, name='preconditioner_decay_rate')
|
||||
|
||||
super(SGLDOptimizer, self)._prepare()
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
rms = self.get_slot(var, 'rms')
|
||||
|
||||
with ops.control_dependencies([
|
||||
self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor,
|
||||
var.dtype.base_dtype))]):
|
||||
new_grad = self._apply_noisy_update(rms, grad)
|
||||
|
||||
return training_ops.apply_gradient_descent(
|
||||
var,
|
||||
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
|
||||
new_grad,
|
||||
use_locking=self._use_locking).op
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
rms = self.get_slot(var, 'rms')
|
||||
|
||||
with ops.control_dependencies([
|
||||
self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor,
|
||||
var.dtype.base_dtype))]):
|
||||
new_grad = self._apply_noisy_update(rms, grad)
|
||||
|
||||
return training_ops.apply_gradient_descent(
|
||||
var,
|
||||
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
|
||||
new_grad,
|
||||
use_locking=self._use_locking).op
|
||||
|
||||
@property
|
||||
def variable_scope(self):
|
||||
"""Variable scope of all calls to `tf.get_variable`."""
|
||||
return self._variable_scope
|
||||
|
||||
def _apply_noisy_update(self, mom, grad):
  """Computes the preconditioned Langevin-dynamics update for `grad`.

  Args:
    mom: The 'rms' slot, a moving average of squared gradients.
    grad: Gradient to precondition and perturb.

  Returns:
    A tensor with the preconditioned gradient term plus the scaled noise term.
  """
  # Noise is switched off (scale 0) until the counter has passed burn-in.
  stddev = array_ops.where(
      array_ops.squeeze(self._counter > self._burnin),
      math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype),
      array_ops.zeros([], grad.dtype))

  preconditioner = math_ops.rsqrt(
      mom + math_ops.cast(self._diagonal_bias, grad.dtype))
  drift = 0.5 * preconditioner * grad * math_ops.cast(
      self._num_pseudo_batches, grad.dtype)
  # The Langevin noise must be zero-mean. The second positional argument of
  # `random_normal` is the mean, so mean and stddev are passed by keyword.
  noise = random_ops.random_normal(
      array_ops.shape(grad), mean=0.0, stddev=1.0, dtype=grad.dtype)
  return drift + noise * stddev * math_ops.sqrt(preconditioner)
|
||||
|
||||
def _update_momentum(self, mom, grad, decay):
  """Moves `mom` towards `grad**2` by a factor of `1 - decay`.

  Keeps an exponentially weighted moving average of squared gradients.
  Not thread safe.

  Args:
    mom: Slot variable holding the running average.
    grad: Current gradient.
    decay: Decay rate of the moving average.

  Returns:
    The result of `mom.assign_add(...)`.
  """
  blend = 1.0 - decay
  innovation = math_ops.square(grad) - mom
  return mom.assign_add(blend * innovation)
|
@ -63,19 +63,26 @@ const char* kPredictionsTensorName = "predictions";
|
||||
void CalculateTreesToInclude(
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||
const std::vector<int32>& trees_to_drop, const int32 num_trees,
|
||||
const bool only_finalized, std::vector<int32>* trees_to_include) {
|
||||
const bool only_finalized, const bool center_bias,
|
||||
std::vector<int32>* trees_to_include) {
|
||||
trees_to_include->reserve(num_trees - trees_to_drop.size());
|
||||
|
||||
int32 index = 0;
|
||||
// This assumes that trees_to_drop is a sorted list of tree ids.
|
||||
for (int32 tree = 0; tree < num_trees; ++tree) {
|
||||
if ((!trees_to_drop.empty() && index < trees_to_drop.size() &&
|
||||
trees_to_drop[index] == tree) ||
|
||||
(only_finalized && config.tree_metadata_size() > 0 &&
|
||||
!config.tree_metadata(tree).is_finalized())) {
|
||||
// Skip the tree if tree is in the list of trees_to_drop.
|
||||
if (!trees_to_drop.empty() && index < trees_to_drop.size() &&
|
||||
trees_to_drop[index] == tree) {
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
// Or skip if the tree is not finalized and only_finalized is set,
|
||||
// with the exception of centering bias.
|
||||
if (only_finalized && !(center_bias && tree == 0) &&
|
||||
config.tree_metadata_size() > 0 &&
|
||||
!config.tree_metadata(tree).is_finalized()) {
|
||||
continue;
|
||||
}
|
||||
trees_to_include->push_back(tree);
|
||||
}
|
||||
}
|
||||
@ -250,7 +257,7 @@ class GradientTreesPredictionOp : public OpKernel {
|
||||
CalculateTreesToInclude(
|
||||
ensemble_resource->decision_tree_ensemble(), dropped_trees,
|
||||
ensemble_resource->decision_tree_ensemble().trees_size(),
|
||||
only_finalized_trees_, &trees_to_include);
|
||||
only_finalized_trees_, center_bias_, &trees_to_include);
|
||||
|
||||
// Allocate output predictions matrix.
|
||||
Tensor* output_predictions_t = nullptr;
|
||||
|
@ -237,6 +237,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
|
||||
VLOG(1) << "Continuing to center bias, delta=" << total_delta;
|
||||
} else {
|
||||
VLOG(1) << "Done centering bias, delta=" << total_delta;
|
||||
ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
|
||||
}
|
||||
Tensor* continue_centering_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
@ -260,7 +261,6 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
|
||||
for (size_t idx = 0; idx < logits_dimension; ++idx) {
|
||||
leaf->mutable_vector()->add_value(0.0);
|
||||
}
|
||||
ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
|
||||
return leaf;
|
||||
} else if (num_trees == 1) {
|
||||
// Confirms that the only tree is a bias and returns its leaf.
|
||||
|
@ -181,7 +181,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
|
||||
tree_weights: 1.0
|
||||
tree_metadata {
|
||||
num_layers_grown: 1
|
||||
is_finalized: true
|
||||
}
|
||||
growing_metadata {
|
||||
num_trees_attempted: 1
|
||||
@ -189,7 +188,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
"""
|
||||
self.assertEqual(new_stamp, 1)
|
||||
self.assertEqual(stats.num_trees, 1)
|
||||
self.assertEqual(stats.num_trees, 0)
|
||||
self.assertEqual(stats.num_layers, 1)
|
||||
self.assertEqual(stats.active_tree, 1)
|
||||
self.assertEqual(stats.active_layer, 1)
|
||||
@ -231,7 +230,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
|
||||
tree_weights: 1.0
|
||||
tree_metadata {
|
||||
num_layers_grown: 1
|
||||
is_finalized: true
|
||||
}
|
||||
growing_metadata {
|
||||
num_trees_attempted: 1
|
||||
@ -239,7 +237,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
"""
|
||||
self.assertEqual(new_stamp, 2)
|
||||
self.assertEqual(stats.num_trees, 1)
|
||||
self.assertEqual(stats.num_trees, 0)
|
||||
self.assertEqual(stats.num_layers, 1)
|
||||
self.assertEqual(stats.active_tree, 1)
|
||||
self.assertEqual(stats.active_layer, 1)
|
||||
|
@ -238,6 +238,7 @@ add_python_module("tensorflow/python/keras/datasets")
|
||||
add_python_module("tensorflow/python/keras/datasets/boston_housing")
|
||||
add_python_module("tensorflow/python/keras/datasets/cifar10")
|
||||
add_python_module("tensorflow/python/keras/datasets/cifar100")
|
||||
add_python_module("tensorflow/python/keras/datasets/fashion_mnist")
|
||||
add_python_module("tensorflow/python/keras/datasets/imdb")
|
||||
add_python_module("tensorflow/python/keras/datasets/mnist")
|
||||
add_python_module("tensorflow/python/keras/datasets/reuters")
|
||||
|
@ -23,7 +23,6 @@ import itertools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.crf.python.ops import crf
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -58,18 +57,19 @@ class CrfTest(test.TestCase):
|
||||
def testCrfUnaryScore(self):
|
||||
inputs = np.array(
|
||||
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
|
||||
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
|
||||
sequence_lengths = np.array(3, dtype=np.int32)
|
||||
with self.test_session() as sess:
|
||||
unary_score = crf.crf_unary_score(
|
||||
tag_indices=array_ops.expand_dims(tag_indices, 0),
|
||||
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
|
||||
inputs=array_ops.expand_dims(inputs, 0))
|
||||
unary_score = array_ops.squeeze(unary_score, [0])
|
||||
tf_unary_score = sess.run(unary_score)
|
||||
expected_unary_score = sum(inputs[i][tag_indices[i]]
|
||||
for i in range(sequence_lengths))
|
||||
self.assertAllClose(tf_unary_score, expected_unary_score)
|
||||
for dtype in (np.int32, np.int64):
|
||||
tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
|
||||
sequence_lengths = np.array(3, dtype=np.int32)
|
||||
with self.test_session() as sess:
|
||||
unary_score = crf.crf_unary_score(
|
||||
tag_indices=array_ops.expand_dims(tag_indices, 0),
|
||||
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
|
||||
inputs=array_ops.expand_dims(inputs, 0))
|
||||
unary_score = array_ops.squeeze(unary_score, [0])
|
||||
tf_unary_score = sess.run(unary_score)
|
||||
expected_unary_score = sum(inputs[i][tag_indices[i]]
|
||||
for i in range(sequence_lengths))
|
||||
self.assertAllClose(tf_unary_score, expected_unary_score)
|
||||
|
||||
def testCrfBinaryScore(self):
|
||||
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
|
||||
|
@ -193,6 +193,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
|
||||
offsets = array_ops.expand_dims(
|
||||
math_ops.range(batch_size) * max_seq_len * num_tags, 1)
|
||||
offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
|
||||
# Use int32 or int64 based on tag_indices' dtype.
|
||||
if tag_indices.dtype == dtypes.int64:
|
||||
offsets = math_ops.to_int64(offsets)
|
||||
flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])
|
||||
|
||||
unary_scores = array_ops.reshape(
|
||||
@ -305,7 +308,7 @@ def viterbi_decode(score, transition_params):
|
||||
|
||||
Returns:
|
||||
viterbi: A [seq_len] list of integers containing the highest scoring tag
|
||||
indicies.
|
||||
indices.
|
||||
viterbi_score: A float containing the score for the Viterbi sequence.
|
||||
"""
|
||||
trellis = np.zeros_like(score)
|
||||
@ -385,7 +388,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
|
||||
"""Initialize the CrfDecodeBackwardRnnCell.
|
||||
|
||||
Args:
|
||||
num_tags: An integer.
|
||||
num_tags: An integer. The number of tags.
|
||||
"""
|
||||
self._num_tags = num_tags
|
||||
|
||||
@ -434,9 +437,9 @@ def crf_decode(potentials, transition_params, sequence_length):
|
||||
sequence_length: A [batch_size] vector of true sequence lengths.
|
||||
|
||||
Returns:
|
||||
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
|
||||
Contains the highest scoring tag indicies.
|
||||
best_score: A [batch_size] vector, containing the score of `decode_tags`.
|
||||
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
|
||||
Contains the highest scoring tag indices.
|
||||
best_score: A [batch_size] tensor, containing the score of decode_tags.
|
||||
"""
|
||||
# For simplicity, in shape comments, denote:
|
||||
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
|
||||
|
@ -270,6 +270,7 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:iterator_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -277,15 +278,20 @@ py_test(
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -25,9 +25,13 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.contrib.data.python.ops import error_ops
|
||||
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
@ -40,7 +44,10 @@ from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -668,6 +675,515 @@ class MapDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testCaptureResourceInMapFn(self):
  """Initializing a map that captures an iterator raises UnimplementedError."""

  def _build_ds(iterator):

    def _map_fn(x):
      # Closing over `iterator` makes the map function capture its resource.
      get_next = iterator.get_next()
      return x * get_next

    source = dataset_ops.Dataset.range(10)
    return source.map(_map_fn)

  def _build_graph():
    captured_iterator = dataset_ops.Dataset.range(
        10).make_initializable_iterator()
    mapped_iterator = _build_ds(captured_iterator).make_initializable_iterator()
    init_op = mapped_iterator.initializer
    return captured_iterator.initializer, init_op

  with ops.Graph().as_default() as graph:
    captured_init_op, init_op = _build_graph()
    with self.test_session(graph=graph) as session:
      session.run(captured_init_op)
      # CapturedFunction does not support capturing IteratorResource.
      with self.assertRaises(errors.UnimplementedError):
        session.run(init_op)
|
||||
|
||||
|
||||
class MapDatasetSerializationTest(test.TestCase):
|
||||
|
||||
def setUp(self):
  """Fixes the dataset dimensions shared by the tests of this class."""
  slice_len, epochs = 7, 14
  self._tensor_slice_len = slice_len
  self._num_epochs = epochs
  # Total number of elements: one pass over the slices per epoch.
  self._num_outputs = slice_len * epochs
|
||||
|
||||
def tearDown(self):
  """Removes every file whose name starts with the checkpoint prefix."""
  pattern = self._ckpt_path() + "*"
  # Iterate explicitly: on Python 3 `map` is lazy, so `map(gfile.Remove, ...)`
  # would never call `gfile.Remove` and the checkpoints would be left behind.
  for path in gfile.Glob(pattern):
    gfile.Remove(path)
|
||||
|
||||
def _build_ds(self, multiplier=37.0):
  """Builds the squared, repeated dataset used by the serialization tests.

  Args:
    multiplier: Scale applied to the third component before squaring.

  Returns:
    A dataset of three-component elements, repeated `self._num_epochs` times.
  """
  indices = np.arange(self._tensor_slice_len)
  components = (
      indices,
      np.array([[1, 2, 3]]) * indices[:, np.newaxis],
      np.array(multiplier) * indices,
  )

  def _map_fn(x, y, z):
    return math_ops.square(x), math_ops.square(y), math_ops.square(z)

  sliced = dataset_ops.Dataset.from_tensor_slices(components)
  squared = sliced.map(_map_fn)
  return squared.repeat(self._num_epochs)
|
||||
|
||||
def _build_graph(self, multiplier=37.0, build_saveable=True):
  """Builds the dataset graph and returns `(init_op, get_next, saver)`.

  Args:
    multiplier: Forwarded to `_build_ds`.
    build_saveable: If True, the iterator is added to the SAVEABLE_OBJECTS
      collection before the saver is created.
  """
  iterator = self._build_ds(multiplier).make_initializable_iterator()

  if build_saveable:
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))

  init_op = iterator.initializer
  get_next = iterator.get_next()
  # Stored in the "iterator_ops" collection next to the graph.
  self._add_iterator_ops_to_collection(init_op, get_next)
  return init_op, get_next, saver_lib.Saver(allow_empty=True)
|
||||
|
||||
def _build_empty_graph(self, output_types, output_shapes):
  """Builds a saveable iterator from its structure only, with no dataset.

  Args:
    output_types: Element types passed to `Iterator.from_structure`.
    output_shapes: Element shapes passed to `Iterator.from_structure`.

  Returns:
    A `(get_next, saver)` pair.
  """
  iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
  ops.add_to_collection(
      ops.GraphKeys.SAVEABLE_OBJECTS,
      contrib_iterator_ops.make_saveable_from_iterator(iterator))
  # The saver is created before the get_next op, as in the original ordering.
  saver = saver_lib.Saver()
  return iterator.get_next(), saver
|
||||
|
||||
def _add_iterator_ops_to_collection(self, init_op, get_next):
  """Stores the initializer and the three get_next components, in that order.

  Args:
    init_op: The iterator initializer.
    get_next: Indexable with at least three components.
  """
  for item in (init_op, get_next[0], get_next[1], get_next[2]):
    ops.add_to_collection("iterator_ops", item)
|
||||
|
||||
def _get_iterator_ops_from_collection(self):
  """Reads `(init_op, (c0, c1, c2))` back from the "iterator_ops" collection.

  The collection must hold exactly four entries; unpacking fails otherwise.
  """
  collected = ops.get_collection("iterator_ops")
  init_op, first, second, third = collected
  return init_op, (first, second, third)
|
||||
|
||||
def _ckpt_path(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _latest_ckpt(self):
  """Returns what `latest_checkpoint` reports for the test's temp directory."""
  checkpoint_dir = self.get_temp_dir()
  return saver_lib.latest_checkpoint(checkpoint_dir)
|
||||
|
||||
def _save(self, sess, saver):
|
||||
saver.save(sess, self._ckpt_path())
|
||||
|
||||
def _restore(self, saver, sess):
|
||||
saver.restore(sess, self._latest_ckpt())
|
||||
|
||||
def _import_meta_graph(self):
  """Imports the `.meta` file next to the checkpoint prefix.

  Returns:
    Whatever `saver_lib.import_meta_graph` returns for that file.
  """
  return saver_lib.import_meta_graph(self._ckpt_path() + ".meta")
|
||||
|
||||
def _testReadWithBreaks(self, break_points, init_before_restore=False):
  """Checks that saving/restoring at `break_points` leaves the output intact.

  Args:
    break_points: Element counts at which the iterator is checkpointed. The
      first entry is indexed directly, so the sequence must not be empty.
    init_before_restore: If True, the initializer is run before each restore.
  """
  # Ground truth: one uninterrupted pass over the whole dataset.
  expected = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, _ = self._build_graph()
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      expected.extend(
          session.run(get_next_op) for _ in range(self._num_outputs))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  # First segment: run up to the first break point, then checkpoint.
  actual = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph()
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      actual.extend(session.run(get_next_op) for _ in range(break_points[0]))
      self._save(session, saver)

  # Remaining segments: each restores the latest checkpoint, runs to the next
  # break point (or to the end of the data) and checkpoints again.
  segment_ends = list(break_points[1:]) + [self._num_outputs]
  for start, end in zip(break_points, segment_ends):
    with ops.Graph().as_default() as graph:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection()
      with self.test_session(graph=graph) as session:
        if init_before_restore:
          session.run(init_op)
        self._restore(saver, session)
        actual.extend(session.run(get_next_op) for _ in range(end - start))
        self._save(session, saver)
        if end == self._num_outputs:
          with self.assertRaises(errors.OutOfRangeError):
            session.run(get_next_op)
  self._match(expected, actual)
|
||||
|
||||
def _match(self, expected, actual):
|
||||
self.assertEqual(len(expected), len(actual))
|
||||
for expected_tuple, actual_tuple in zip(expected, actual):
|
||||
self.assertEqual(expected_tuple[0], actual_tuple[0])
|
||||
self.assertSequenceEqual(expected_tuple[1].tolist(),
|
||||
actual_tuple[1].tolist())
|
||||
self.assertEqual(expected_tuple[2], actual_tuple[2])
|
||||
|
||||
def _does_not_match(self, expected, actual):
|
||||
with self.assertRaises(AssertionError):
|
||||
self._match(expected, actual)
|
||||
|
||||
def testSaveRestore(self):
|
||||
self._testReadWithBreaks([4])
|
||||
self._testReadWithBreaks([13])
|
||||
self._testReadWithBreaks([18])
|
||||
self._testReadWithBreaks([23])
|
||||
|
||||
def testSaveUnusedIterator(self):
|
||||
self._testReadWithBreaks([0])
|
||||
|
||||
def testSaveFullyUsedIterator(self):
|
||||
self._testReadWithBreaks([self._num_outputs])
|
||||
|
||||
def testMultipleBreaks(self):
|
||||
self._testReadWithBreaks([0, 5, 9, 15, 25, 32])
|
||||
|
||||
def testIdempotence(self):
|
||||
# Attempt to save iterator immediately after restoring.
|
||||
self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32])
|
||||
|
||||
def testInitThenRestore(self):
|
||||
self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True)
|
||||
|
||||
def testRestoreExhaustedIterator(self):
  """A checkpoint taken after exhaustion restores to an exhausted iterator."""
  # Drain the iterator completely, then checkpoint it.
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph()
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      remaining = self._num_outputs
      while remaining > 0:
        session.run(get_next_op)
        remaining -= 1
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)
      self._save(session, saver)

  # After restoring into the re-imported graph, nothing more may be produced.
  with ops.Graph().as_default() as graph:
    saver = self._import_meta_graph()
    init_op, get_next_op = self._get_iterator_ops_from_collection()
    with self.test_session(graph=graph) as session:
      self._restore(saver, session)
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)
|
||||
|
||||
def testResetRestoredIterator(self):
  """Running the initializer after a restore restarts from the beginning."""
  # Collect ground truth containing all outputs, checkpointing half-way.
  expected = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph()
    break_point = self._num_outputs // 2
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      expected.extend(session.run(get_next_op) for _ in range(break_point))
      self._save(session, saver)
      expected.extend(
          session.run(get_next_op)
          for _ in range(self._num_outputs - break_point))

  # Restore from checkpoint and then run init_op.
  actual = []
  with ops.Graph().as_default() as graph:
    saver = self._import_meta_graph()
    init_op, get_next_op = self._get_iterator_ops_from_collection()
    with self.test_session(graph=graph) as session:
      self._restore(saver, session)
      session.run(init_op)
      actual.extend(
          session.run(get_next_op) for _ in range(self._num_outputs))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)
  self._match(expected, actual)
|
||||
|
||||
def testRestoreInModifiedGraph(self):
  """Restoring into a graph built with another multiplier replays the old one."""
  break_point = 10
  expected = []
  actual = []
  # Reference run (multiplier 15), checkpointed after `break_point` elements.
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(multiplier=15.0)
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      expected.extend(session.run(get_next_op) for _ in range(break_point))
      # The restored run only yields the tail, so `actual` starts with the
      # head produced here.
      actual.extend(expected)
      self._save(session, saver)
      expected.extend(
          session.run(get_next_op)
          for _ in range(self._num_outputs - break_point))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  # Collect outputs by running modified graph.
  actual_without_restore = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(multiplier=30.0)
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      actual_without_restore.extend(
          session.run(get_next_op) for _ in range(self._num_outputs))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  # Restore the checkpoint in the modified graph.
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(multiplier=30.0)
    with self.test_session(graph=graph) as session:
      self._restore(saver, session)
      actual.extend(
          session.run(get_next_op)
          for _ in range(self._num_outputs - break_point))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  # Ensure the modified graph gets overridden when restoring checkpoint.
  self._does_not_match(expected, actual_without_restore)
  # Expect that the outputs are what we would expect if we ran the old
  # graph.
  self._match(expected, actual)
|
||||
|
||||
# TODO(srbs): Add this test to dataset_serialization_test_base.py.
|
||||
def testRestoreInEmptyGraph(self):
  """Restores a checkpoint into an iterator built from its structure only."""
  break_point = 10
  tail_length = self._num_outputs - break_point

  # Checkpoint after `break_point` elements; the rest is the expected output.
  expected = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(multiplier=15.0)
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      for _ in range(break_point):
        session.run(get_next_op)
      self._save(session, saver)
      expected.extend(session.run(get_next_op) for _ in range(tail_length))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  # A throwaway graph, used only to read the dataset's types and shapes.
  with ops.Graph().as_default() as graph:
    dataset = self._build_ds()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes

  actual = []
  with ops.Graph().as_default() as graph:
    get_next_op, saver = self._build_empty_graph(output_types, output_shapes)
    with self.test_session(graph=graph) as session:
      self._restore(saver, session)
      actual.extend(session.run(get_next_op) for _ in range(tail_length))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  # Expect that the outputs are what we would expect if we ran the old
  # graph.
  self._match(expected, actual)
|
||||
|
||||
def testDoNotBuildSaveable(self):
  """Without a registered saveable, restore leaves the iterator untouched."""
  break_point = 10
  # Write a checkpoint from a graph that does register the saveable.
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(multiplier=15.0)
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      for _ in range(break_point):
        session.run(get_next_op)
      self._save(session, saver)

  # Collect ground truth by running modified graph.
  expected = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(multiplier=30.0)
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      expected.extend(
          session.run(get_next_op) for _ in range(self._num_outputs))

  actual = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = self._build_graph(
        multiplier=30.0, build_saveable=False)
    with self.test_session(graph=graph) as session:
      # Since the SaveableObject was not added to Saver's list
      # of saveables, iterator state is not restored by saver.restore().
      self._restore(saver, session)
      with self.assertRaises(errors.FailedPreconditionError):
        session.run(get_next_op)
      session.run(init_op)
      actual.extend(
          session.run(get_next_op) for _ in range(self._num_outputs))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)
  self._match(expected, actual)
|
||||
|
||||
def testSaveStatefulFunction(self):
  """Saving an iterator whose map function draws random numbers must fail."""

  def _build_ds():

    def _map_fn(x):
      return random_ops.random_uniform(
          (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)

    source = dataset_ops.Dataset.range(100)
    return source.map(_map_fn)

  def _build_graph():
    iterator = _build_ds().make_initializable_iterator()
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))
    init_op = iterator.initializer
    get_next = iterator.get_next()
    return init_op, get_next, saver_lib.Saver(allow_empty=True)

  break_point = 10
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = _build_graph()
    with self.test_session(graph=graph) as session:
      session.run(init_op)
      for _ in range(break_point):
        session.run(get_next_op)
      # The save itself is expected to be rejected.
      with self.assertRaises(errors.InvalidArgumentError):
        self._save(session, saver)
|
||||
|
||||
def testCaptureVariableInMapFn(self):
  """Saving an iterator whose map function updates a variable must fail."""

  def _build_ds():
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    repeated = dataset_ops.Dataset.from_tensors(0).repeat(10)
    return repeated.map(lambda _: counter_var.assign_add(1))

  def _build_graph():
    iterator = _build_ds().make_initializable_iterator()
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))
    init_op = iterator.initializer
    get_next = iterator.get_next()
    return init_op, get_next, saver_lib.Saver(allow_empty=True)

  break_point = 10
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = _build_graph()
    with self.test_session(graph=graph) as session:
      session.run(variables.global_variables_initializer())
      session.run(init_op)
      for _ in range(break_point):
        session.run(get_next_op)
      # The save itself is expected to be rejected.
      with self.assertRaises(errors.InvalidArgumentError):
        self._save(session, saver)
|
||||
|
||||
def testCaptureDefunInMapFn(self):
  """An iterator mapped with a Defun restores into a structure-only graph."""
  num_outputs = 100
  break_point = 10
  tail_length = num_outputs - break_point

  def _build_ds():

    @function.Defun(dtypes.int64)
    def defun_fn(x):
      return constant_op.constant(1000) + math_ops.to_int32(x)

    return dataset_ops.Dataset.range(num_outputs).map(defun_fn)

  def _build_graph():
    iterator = _build_ds().make_initializable_iterator()
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))
    init_op = iterator.initializer
    get_next = iterator.get_next()
    return init_op, get_next, saver_lib.Saver(allow_empty=True)

  # Checkpoint after `break_point` elements; the rest is the expected output.
  expected = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = _build_graph()
    with self.test_session(graph=graph) as session:
      session.run(variables.global_variables_initializer())
      session.run(init_op)
      for _ in range(break_point):
        session.run(get_next_op)
      self._save(session, saver)
      expected.extend(session.run(get_next_op) for _ in range(tail_length))

  # A throwaway graph, used only to read the dataset's types and shapes.
  with ops.Graph().as_default() as graph:
    dataset = _build_ds()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes

  actual = []
  with ops.Graph().as_default() as graph:
    get_next_op, saver = self._build_empty_graph(output_types, output_shapes)
    with self.test_session(graph=graph) as session:
      self._restore(saver, session)
      actual.extend(session.run(get_next_op) for _ in range(tail_length))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  self.assertSequenceEqual(expected, actual)
|
||||
|
||||
def testBuildDefunInMapFn(self):
  """Like the captured-Defun test, with a Defun defined inside the Defun."""
  num_outputs = 100
  break_point = 10
  tail_length = num_outputs - break_point

  def _build_ds():

    @function.Defun(dtypes.int64)
    def defun_fn(x):

      @function.Defun(dtypes.int32)
      def defun_fn_deep(x):
        return constant_op.constant(1000) + math_ops.to_int32(x)

      return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))

    return dataset_ops.Dataset.range(num_outputs).map(defun_fn)

  def _build_graph():
    iterator = _build_ds().make_initializable_iterator()
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))
    init_op = iterator.initializer
    get_next = iterator.get_next()
    return init_op, get_next, saver_lib.Saver(allow_empty=True)

  # Checkpoint after `break_point` elements; the rest is the expected output.
  expected = []
  with ops.Graph().as_default() as graph:
    init_op, get_next_op, saver = _build_graph()
    with self.test_session(graph=graph) as session:
      session.run(variables.global_variables_initializer())
      session.run(init_op)
      for _ in range(break_point):
        session.run(get_next_op)
      self._save(session, saver)
      expected.extend(session.run(get_next_op) for _ in range(tail_length))

  # A throwaway graph, used only to read the dataset's types and shapes.
  with ops.Graph().as_default() as graph:
    dataset = _build_ds()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes

  actual = []
  with ops.Graph().as_default() as graph:
    get_next_op, saver = self._build_empty_graph(output_types, output_shapes)
    with self.test_session(graph=graph) as session:
      self._restore(saver, session)
      actual.extend(session.run(get_next_op) for _ in range(tail_length))
      with self.assertRaises(errors.OutOfRangeError):
        session.run(get_next_op)

  self.assertSequenceEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -50,21 +50,22 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:prefetching_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dataset_ops_gen",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "datasets_test",
|
||||
srcs = ["datasets_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
additional_deps = [
|
||||
":datasets",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
@ -240,6 +241,7 @@ py_test(
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:function",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
@ -20,11 +20,15 @@ from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.contrib.data.python.ops import prefetching_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
|
||||
@ -32,12 +36,12 @@ _uid_counter = 0
|
||||
_uid_lock = threading.Lock()
|
||||
|
||||
|
||||
def _iterator_shared_name():
|
||||
def _generate_shared_name(prefix):
|
||||
with _uid_lock:
|
||||
global _uid_counter
|
||||
uid = _uid_counter
|
||||
_uid_counter += 1
|
||||
return "eager_iterator_{}".format(uid)
|
||||
return "{}_{}".format(prefix, uid)
|
||||
|
||||
|
||||
class Iterator(object):
|
||||
@ -72,11 +76,12 @@ class Iterator(object):
|
||||
with ops.device("/device:CPU:0"):
|
||||
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
self._output_types = dataset.output_types
|
||||
self._output_shapes = dataset.output_shapes
|
||||
self._flat_output_types = nest.flatten(dataset.output_types)
|
||||
self._flat_output_shapes = nest.flatten(dataset.output_shapes)
|
||||
self._resource = gen_dataset_ops.iterator(
|
||||
container="",
|
||||
shared_name=_iterator_shared_name(),
|
||||
shared_name=_generate_shared_name("eager_iterator"),
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
gen_dataset_ops.make_iterator(ds_variant, self._resource)
|
||||
@ -84,6 +89,35 @@ class Iterator(object):
|
||||
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
|
||||
handle=self._resource, handle_device="/device:CPU:0")
|
||||
self._device = context.context().device_name
|
||||
self._buffer_resource_handle = None
|
||||
if not context.context().device_spec.device_type:
|
||||
is_remote_device = False
|
||||
else:
|
||||
is_remote_device = context.context().device_spec.device_type != "CPU"
|
||||
if is_remote_device:
|
||||
with ops.device("/device:CPU:0"):
|
||||
iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
|
||||
self._resource)
|
||||
|
||||
@function.Defun(dtypes.string)
|
||||
def remote_fn(h):
|
||||
remote_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
h, self._output_types, self._output_shapes)
|
||||
return remote_iterator.get_next()
|
||||
|
||||
remote_fn.add_to_graph(None)
|
||||
target = constant_op.constant("/device:CPU:0")
|
||||
with ops.device(self._device):
|
||||
self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
|
||||
string_arg=iter_string_handle,
|
||||
f=remote_fn,
|
||||
target_device=target,
|
||||
buffer_size=10,
|
||||
thread_pool_size=1,
|
||||
container="",
|
||||
shared_name=_generate_shared_name("function_buffer_resource"))
|
||||
self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
|
||||
handle=self._buffer_resource_handle, handle_device=self._device)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@ -93,20 +127,20 @@ class Iterator(object):
|
||||
|
||||
def next(self):
|
||||
"""Return the next tf.Tensor from the dataset."""
|
||||
try:
|
||||
# TODO(ashankar): Consider removing this ops.device() contextmanager
|
||||
# and instead mimic ops placement in graphs: Operations on resource
|
||||
# handles execute on the same device as where the resource is placed.
|
||||
with ops.device("/device:CPU:0"):
|
||||
ret = gen_dataset_ops.iterator_get_next(
|
||||
self._resource,
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
except errors.OutOfRangeError:
|
||||
raise StopIteration
|
||||
# Copies tensors from CPU to the current device if necessary.
|
||||
# TODO(rohanj): This should be replaced by the mechanism to have the
|
||||
# runtime's threads copy tensors to the destination device.
|
||||
with ops.device(self._device):
|
||||
ret = [array_ops.identity(x) for x in ret]
|
||||
try:
|
||||
if self._buffer_resource_handle is not None:
|
||||
ret = prefetching_ops.function_buffering_resource_get_next(
|
||||
function_buffer_resource=self._buffer_resource_handle,
|
||||
output_types=self._flat_output_types)
|
||||
else:
|
||||
# TODO(ashankar): Consider removing this ops.device() contextmanager
|
||||
# and instead mimic ops placement in graphs: Operations on resource
|
||||
# handles execute on the same device as where the resource is placed.
|
||||
ret = gen_dataset_ops.iterator_get_next(
|
||||
self._resource,
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
except errors.OutOfRangeError:
|
||||
raise StopIteration
|
||||
return nest.pack_sequence_as(self._output_types, ret)
|
||||
|
@ -87,7 +87,7 @@ class EvaluatorTest(test.TestCase):
|
||||
|
||||
e.all_metric_results(logdir)
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
events = summary_test_util.events_from_logdir(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
|
||||
|
||||
@ -136,7 +136,7 @@ class EvaluatorTest(test.TestCase):
|
||||
variables.global_variables_initializer().run()
|
||||
e.run_evaluation(init_op, call_op, results_op)
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
events = summary_test_util.events_from_logdir(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
|
||||
|
||||
|
@ -95,7 +95,7 @@ class ResNet50GraphTest(tf.test.TestCase):
|
||||
sess.run([train_op, tf.contrib.summary.all_summary_ops()],
|
||||
feed_dict={images: np_images, labels: np_labels})
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
events = summary_test_util.events_from_logdir(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
|
@ -103,7 +103,7 @@ class ResNet50Test(tf.test.TestCase):
|
||||
images, labels = random_batch(2)
|
||||
train_one_step(model, images, labels, optimizer)
|
||||
self.assertEqual(320, len(model.variables))
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
events = summary_test_util.events_from_logdir(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
|
@ -72,7 +72,7 @@ class MetricsTest(test.TestCase):
|
||||
name="t0").as_default(), summary_ops.always_record_summaries():
|
||||
m.result() # As a side-effect will write summaries.
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
events = summary_test_util.events_from_logdir(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -19,9 +19,12 @@ from __future__ import print_function
|
||||
import gc
|
||||
|
||||
from tensorflow.contrib.eager.python import network
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -46,8 +49,8 @@ class NetworkTest(test.TestCase):
|
||||
|
||||
def _save_modify_load_network_built(self, net, global_step=None):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_path = net.save(
|
||||
save_path=checkpoint_directory, global_step=global_step)
|
||||
checkpoint_path = network.save_network_checkpoint(
|
||||
network=net, save_path=checkpoint_directory, global_step=global_step)
|
||||
input_value = constant_op.constant([[42.0]])
|
||||
original_output = self.evaluate(net(input_value))
|
||||
for var in net.variables:
|
||||
@ -56,13 +59,13 @@ class NetworkTest(test.TestCase):
|
||||
self.evaluate(net(input_value)),
|
||||
original_output)
|
||||
# Either the returned explicit checkpoint path or the directory should work.
|
||||
net.restore(save_path=checkpoint_directory)
|
||||
network.restore_network_checkpoint(net, save_path=checkpoint_directory)
|
||||
self.assertAllEqual(
|
||||
original_output,
|
||||
self.evaluate(net(input_value)))
|
||||
for var in net.variables:
|
||||
self.evaluate(var.assign(var + 2.))
|
||||
net.restore(save_path=checkpoint_path)
|
||||
network.restore_network_checkpoint(net, save_path=checkpoint_path)
|
||||
self.assertAllEqual(
|
||||
original_output,
|
||||
self.evaluate(net(input_value)))
|
||||
@ -85,13 +88,30 @@ class NetworkTest(test.TestCase):
|
||||
result = net(constant_op.constant([[2.0]]))
|
||||
self.assertEqual(34.0, self.evaluate(result))
|
||||
|
||||
# TODO(akshayka): This test should be changed once an API for compiling
|
||||
# `call` into a defun is implemented.
|
||||
def testReplacingNetworkCallWithDefun(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
x = constant_op.constant([[2.0]])
|
||||
net(x) # Force variables to be created.
|
||||
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
|
||||
|
||||
net.call = function.defun(net.call)
|
||||
result = net(x) # Build and execute the TensorFlow function
|
||||
self.assertEqual(34.0, self.evaluate(result))
|
||||
|
||||
# Force the creation of another TensorFlow function by changing input shape
|
||||
y = constant_op.constant([[1.0], [2.0]])
|
||||
result = net(y)
|
||||
self.assertAllEqual([[17.0], [34.0]], self.evaluate(result))
|
||||
|
||||
# TODO(allenl): This test creates garbage in some Python versions
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNetworkSaveRestoreAlreadyBuilt(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Attempt to save the Network before it was first called"):
|
||||
net.save(self.get_temp_dir())
|
||||
network.save_network_checkpoint(net, self.get_temp_dir())
|
||||
net(constant_op.constant([[2.0]]))
|
||||
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
|
||||
self._save_modify_load_network_built(net, global_step=None)
|
||||
@ -105,7 +125,7 @@ class NetworkTest(test.TestCase):
|
||||
self.evaluate(net.variables[0].assign([[3.]]))
|
||||
default_global_step = training_util.get_or_create_global_step()
|
||||
self.evaluate(default_global_step.assign(4242))
|
||||
save_path = net.save(self.get_temp_dir())
|
||||
save_path = network.save_network_checkpoint(net, self.get_temp_dir())
|
||||
self.assertIn("abcd-4242", save_path)
|
||||
|
||||
# TODO(allenl): This test creates garbage in some Python versions
|
||||
@ -116,16 +136,43 @@ class NetworkTest(test.TestCase):
|
||||
test_input = constant_op.constant([[2.0]])
|
||||
net1(test_input)
|
||||
self.evaluate(net1.trainable_variables[0].assign([[17.0]]))
|
||||
save_path = net1.save(save_dir)
|
||||
save_path = network.save_network_checkpoint(net1, save_dir)
|
||||
# With a pre-build restore we should have the same value.
|
||||
net2 = MyNetwork()
|
||||
net2.restore(save_path)
|
||||
network.restore_network_checkpoint(net2, save_path)
|
||||
self.assertAllEqual(self.evaluate(net1(test_input)),
|
||||
self.evaluate(net2(test_input)))
|
||||
self.assertIsNot(net1.variables[0], net2.variables[0])
|
||||
self.assertAllEqual(self.evaluate(net1.variables[0]),
|
||||
self.evaluate(net2.variables[0]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNetworkMatchesLayerVariableNames(self):
|
||||
zero = constant_op.constant([[0.]])
|
||||
layer_one = core.Dense(1, use_bias=False)
|
||||
layer_one(zero)
|
||||
layer_two = core.Dense(1, use_bias=False)
|
||||
layer_two(zero)
|
||||
|
||||
class TwoLayerNet(network.Network):
|
||||
|
||||
def __init__(self, name=None):
|
||||
super(TwoLayerNet, self).__init__(name=name)
|
||||
self.first = self.track_layer(core.Dense(
|
||||
1, use_bias=False))
|
||||
self.second = self.track_layer(core.Dense(
|
||||
1, use_bias=False))
|
||||
|
||||
def call(self, x):
|
||||
return self.second(self.first(x))
|
||||
|
||||
net = TwoLayerNet()
|
||||
net(zero)
|
||||
self.assertEqual("two_layer_net/" + layer_one.variables[0].name,
|
||||
net.first.variables[0].name)
|
||||
self.assertEqual("two_layer_net/" + layer_two.variables[0].name,
|
||||
net.second.variables[0].name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testLoadIntoUnbuiltSharedLayer(self):
|
||||
|
||||
@ -173,14 +220,15 @@ class NetworkTest(test.TestCase):
|
||||
# Re-map the variable names so that with default restore mapping we'll
|
||||
# attempt to restore into the unbuilt Layer.
|
||||
name_mapping = {
|
||||
"checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel",
|
||||
"checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel",
|
||||
"checkpoint_creator/second_layer/kernel": "second_layer/kernel",
|
||||
}
|
||||
save_path = checkpoint_creator.save(
|
||||
save_path = network.save_network_checkpoint(
|
||||
checkpoint_creator,
|
||||
self.get_temp_dir(),
|
||||
map_func=lambda full_name: name_mapping[full_name])
|
||||
load_into = User(use_layer=first_owner.first)
|
||||
load_into.restore(save_path)
|
||||
network.restore_network_checkpoint(load_into, save_path)
|
||||
self.assertEqual(0, len(first_owner.variables))
|
||||
self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
|
||||
self.evaluate(load_into(one)))
|
||||
@ -196,12 +244,13 @@ class NetworkTest(test.TestCase):
|
||||
del first_owner
|
||||
gc.collect()
|
||||
def _restore_map_func(original_name):
|
||||
if original_name.startswith("owner_1"):
|
||||
return original_name.replace("owner_1", "owner_2")
|
||||
if original_name.startswith("owner/"):
|
||||
return original_name.replace("owner/", "owner_1/")
|
||||
else:
|
||||
return "user_2/" + original_name
|
||||
return "user_1/" + original_name
|
||||
with self.assertRaisesRegexp(ValueError, "garbage collected"):
|
||||
load_into.restore(save_path, map_func=_restore_map_func)
|
||||
network.restore_network_checkpoint(
|
||||
load_into, save_path, map_func=_restore_map_func)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testRestoreIntoSubNetwork(self):
|
||||
@ -221,17 +270,18 @@ class NetworkTest(test.TestCase):
|
||||
whole_model_saver(one)
|
||||
self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
|
||||
self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
|
||||
whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir())
|
||||
whole_model_checkpoint = network.save_network_checkpoint(
|
||||
whole_model_saver, self.get_temp_dir())
|
||||
|
||||
save_from = MyNetwork()
|
||||
save_from(one)
|
||||
self.evaluate(save_from.variables[0].assign([[5.]]))
|
||||
checkpoint = save_from.save(self.get_temp_dir())
|
||||
checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir())
|
||||
save_into_parent = Parent()
|
||||
save_into_parent.restore(whole_model_checkpoint)
|
||||
save_into_parent.first.restore(checkpoint)
|
||||
save_into_parent.first.restore(checkpoint) # deferred loading multiple
|
||||
# times is fine
|
||||
network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
|
||||
network.restore_network_checkpoint(save_into_parent.first, checkpoint)
|
||||
# deferred loading multiple times is fine
|
||||
network.restore_network_checkpoint(save_into_parent.first, checkpoint)
|
||||
save_into_parent(one) # deferred loading
|
||||
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
|
||||
self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
|
||||
@ -240,9 +290,9 @@ class NetworkTest(test.TestCase):
|
||||
# (deferred restoration should happen the same way non-deferred happens,
|
||||
# with later restorations overwriting older ones).
|
||||
save_into_parent = Parent()
|
||||
save_into_parent.first.restore(checkpoint) # deferred loading multiple
|
||||
# times is fine
|
||||
save_into_parent.restore(whole_model_checkpoint)
|
||||
# deferred loading multiple times is fine
|
||||
network.restore_network_checkpoint(save_into_parent.first, checkpoint)
|
||||
network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
|
||||
save_into_parent(one) # deferred loading
|
||||
# We've overwritten the sub-Network restore.
|
||||
self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
|
||||
@ -250,12 +300,12 @@ class NetworkTest(test.TestCase):
|
||||
|
||||
self.evaluate(save_into_parent.variables[0].assign([[3.]]))
|
||||
self.evaluate(save_into_parent.variables[1].assign([[4.]]))
|
||||
save_into_parent.second.restore(checkpoint)
|
||||
network.restore_network_checkpoint(save_into_parent.second, checkpoint)
|
||||
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
|
||||
with self.assertRaisesRegexp(errors_impl.NotFoundError,
|
||||
"not found in checkpoint"):
|
||||
# The checkpoint is incompatible.
|
||||
save_into_parent.restore(checkpoint)
|
||||
network.restore_network_checkpoint(save_into_parent, checkpoint)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testCustomMapCollisionErrors(self):
|
||||
@ -277,31 +327,36 @@ class NetworkTest(test.TestCase):
|
||||
self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"The map_func passed to Network.save for the Network 'parent_1' "
|
||||
"resulted in two variables named 'foo'"):
|
||||
make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo")
|
||||
checkpoint = make_checkpoint.first.save(
|
||||
self.get_temp_dir(), map_func=lambda n: "foo")
|
||||
"The map_func passed to save_network_checkpoint for the Network "
|
||||
"'parent' resulted in two variables named 'foo'"):
|
||||
network.save_network_checkpoint(
|
||||
make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo")
|
||||
checkpoint = network.save_network_checkpoint(
|
||||
network=make_checkpoint.first,
|
||||
save_path=self.get_temp_dir(),
|
||||
map_func=lambda n: "foo")
|
||||
loader = Parent()
|
||||
loader.restore(checkpoint, map_func=lambda n: "foo")
|
||||
network.restore_network_checkpoint(
|
||||
loader, checkpoint, map_func=lambda n: "foo")
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
("The map_func passed to Network.restore for the Network"
|
||||
" 'parent_2' resulted in two variables named 'foo'")):
|
||||
("The map_func passed to restore_network_checkpoint for the Network"
|
||||
" 'parent_1' resulted in two variables named 'foo'")):
|
||||
loader(one)
|
||||
loader = Parent()
|
||||
loader(one)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
("The map_func passed to Network.restore for the Network"
|
||||
" 'parent_3' resulted in two variables named 'foo'")):
|
||||
loader.restore(checkpoint, map_func=lambda n: "foo")
|
||||
("The map_func passed to restore_network_checkpoint for the Network"
|
||||
" 'parent_2' resulted in two variables named 'foo'")):
|
||||
network.restore_network_checkpoint(
|
||||
loader, checkpoint, map_func=lambda n: "foo")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testDefaultMapCollisionErrors(self):
|
||||
|
||||
one = constant_op.constant([[1.]])
|
||||
first = core.Dense(1, name="dense_1", use_bias=False)
|
||||
first = core.Dense(1, name="dense", use_bias=False)
|
||||
first(one)
|
||||
|
||||
class Parent(network.Network):
|
||||
@ -322,8 +377,8 @@ class NetworkTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
("The default checkpoint variable name mapping strategy for Network "
|
||||
"'parent_1' resulted in a naming conflict.")):
|
||||
make_checkpoint.save(self.get_temp_dir())
|
||||
"'parent' resulted in a naming conflict.")):
|
||||
network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())
|
||||
|
||||
class Compatible(network.Network):
|
||||
|
||||
@ -337,14 +392,15 @@ class NetworkTest(test.TestCase):
|
||||
successful_checkpoint = Compatible()
|
||||
successful_checkpoint(one)
|
||||
self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
|
||||
checkpoint_path = successful_checkpoint.save(self.get_temp_dir())
|
||||
checkpoint_path = network.save_network_checkpoint(
|
||||
successful_checkpoint, self.get_temp_dir())
|
||||
load_checkpoint = Parent()
|
||||
load_checkpoint(one)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
("The default checkpoint variable name mapping strategy for Network "
|
||||
"'parent_2' resulted in a naming conflict.")):
|
||||
load_checkpoint.restore(checkpoint_path)
|
||||
"'parent_1' resulted in a naming conflict.")):
|
||||
network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
|
||||
|
||||
def testNoReferenceCyclesAfterCall(self):
|
||||
|
||||
@ -398,6 +454,36 @@ class NetworkTest(test.TestCase):
|
||||
self.assertIsInstance(net.trainable_weights[0],
|
||||
resource_variable_ops.ResourceVariable)
|
||||
|
||||
def testGraphOpNames(self):
|
||||
"""Network operation names should match variable naming."""
|
||||
|
||||
def _check_op_prefixes(expected_prefix, checked_ops):
|
||||
for operation in ops.get_default_graph().get_operations():
|
||||
if operation.name == "ignore":
|
||||
continue
|
||||
if operation.name in checked_ops:
|
||||
continue
|
||||
checked_ops.add(operation.name)
|
||||
self.assertStartsWith(expected_start=expected_prefix,
|
||||
actual=operation.name)
|
||||
self.assertNotIn("my_network", operation.name[len(expected_prefix):])
|
||||
self.assertNotIn("dense", operation.name[len(expected_prefix):])
|
||||
|
||||
with context.graph_mode():
|
||||
net = MyNetwork()
|
||||
zero = constant_op.constant([[0.]], name="ignore")
|
||||
net(zero)
|
||||
checked_ops = set()
|
||||
_check_op_prefixes(expected_prefix="my_network/dense/",
|
||||
checked_ops=checked_ops)
|
||||
net.net2 = net.track_layer(MyNetwork())
|
||||
net.net2(zero)
|
||||
_check_op_prefixes(expected_prefix="my_network/my_network/dense/",
|
||||
checked_ops=checked_ops)
|
||||
MyNetwork()(zero)
|
||||
_check_op_prefixes(expected_prefix="my_network_1/dense/",
|
||||
checked_ops=checked_ops)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testDuplicateNameError(self):
|
||||
one = constant_op.constant([[1.]])
|
||||
@ -414,25 +500,25 @@ class NetworkTest(test.TestCase):
|
||||
# Naming happens in the order of first build rather than the order of
|
||||
# construction, but for clarity they're the same here and construction is
|
||||
# annotated.
|
||||
outside_net_before = MyNetwork() # name=my_network_1
|
||||
outside_net_before = MyNetwork() # name=my_network
|
||||
outside_net_before(one)
|
||||
captured_scope = variable_scope.get_variable_scope()
|
||||
with variable_scope.variable_scope("outside_scope"):
|
||||
net1 = MyNetwork() # name=outside_scope/my_network_1
|
||||
net1 = MyNetwork() # name=outside_scope/my_network
|
||||
net1(one)
|
||||
name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far
|
||||
name_conflict2 = MyNetwork(name="name_conflict") # error on build
|
||||
with variable_scope.variable_scope("inside_scope"):
|
||||
# No issue here since the name is unique within its scope.
|
||||
name_conflict3 = MyNetwork(name="name_conflict")
|
||||
net2 = MyNetwork() # name=outside_scope/my_network_3 to avoid the
|
||||
# variable_scope my_network_2 below.
|
||||
net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the
|
||||
# variable_scope my_network_1 below.
|
||||
vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below
|
||||
with variable_scope.variable_scope("intervening_scope"):
|
||||
with variable_scope.variable_scope(captured_scope):
|
||||
with variable_scope.variable_scope("outside_scope"):
|
||||
name_conflict4 = MyNetwork(name="name_conflict") # error on build
|
||||
with variable_scope.variable_scope("my_network_2"):
|
||||
with variable_scope.variable_scope("my_network_1"):
|
||||
pass
|
||||
with variable_scope.variable_scope("vs_name_conflict"):
|
||||
pass
|
||||
@ -452,35 +538,35 @@ class NetworkTest(test.TestCase):
|
||||
self.assertEqual("outside_scope/name_conflict",
|
||||
name_conflict1.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="outside_scope/name_conflict/dense_1/",
|
||||
expected_start="outside_scope/name_conflict/dense/",
|
||||
actual=name_conflict1.variables[0].name)
|
||||
self.assertEqual("outside_scope/inside_scope/name_conflict",
|
||||
name_conflict3.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="outside_scope/inside_scope/name_conflict/dense_1/",
|
||||
expected_start="outside_scope/inside_scope/name_conflict/dense/",
|
||||
actual=name_conflict3.variables[0].name)
|
||||
self.assertEqual("outside_scope/my_network_1", net1.name)
|
||||
self.assertEqual("outside_scope/my_network", net1.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="outside_scope/my_network_1/dense_1/",
|
||||
expected_start="outside_scope/my_network/dense/",
|
||||
actual=net1.trainable_weights[0].name)
|
||||
self.assertEqual("outside_scope/my_network_3", net2.name)
|
||||
self.assertEqual("outside_scope/my_network_2", net2.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="outside_scope/my_network_3/dense_1/",
|
||||
expected_start="outside_scope/my_network_2/dense/",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
net3(one)
|
||||
self.assertEqual("outside_scope/my_network_4", net3.name)
|
||||
self.assertEqual("outside_scope/my_network_3", net3.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="outside_scope/my_network_4/dense_1/",
|
||||
expected_start="outside_scope/my_network_3/dense/",
|
||||
actual=net3.trainable_weights[0].name)
|
||||
outside_net_after = MyNetwork()
|
||||
outside_net_after(one)
|
||||
self.assertEqual("my_network_1", outside_net_before.name)
|
||||
self.assertEqual("my_network", outside_net_before.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="my_network_1/dense_1/",
|
||||
expected_start="my_network/dense/",
|
||||
actual=outside_net_before.trainable_weights[0].name)
|
||||
self.assertEqual("my_network_2", outside_net_after.name)
|
||||
self.assertEqual("my_network_1", outside_net_after.name)
|
||||
self.assertStartsWith(
|
||||
expected_start="my_network_2/dense_1/",
|
||||
expected_start="my_network_1/dense/",
|
||||
actual=outside_net_after.trainable_weights[0].name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@ -490,21 +576,21 @@ class NetworkTest(test.TestCase):
|
||||
net = MyNetwork()
|
||||
net(constant_op.constant([[2.0]]))
|
||||
self.evaluate(net.variables[0].assign([[42.]]))
|
||||
self.assertEqual(net.name, "scope1/scope2/my_network_1")
|
||||
self.assertEqual(net.name, "scope1/scope2/my_network")
|
||||
self.assertStartsWith(
|
||||
expected_start="scope1/scope2/my_network_1/dense_1/",
|
||||
expected_start="scope1/scope2/my_network/dense/",
|
||||
actual=net.trainable_weights[0].name)
|
||||
save_path = net.save(self.get_temp_dir())
|
||||
self.assertIn("scope1_scope2_my_network_1", save_path)
|
||||
save_path = network.save_network_checkpoint(net, self.get_temp_dir())
|
||||
self.assertIn("scope1_scope2_my_network", save_path)
|
||||
restore_net = MyNetwork()
|
||||
# Delayed restoration
|
||||
restore_net.restore(save_path)
|
||||
network.restore_network_checkpoint(restore_net, save_path)
|
||||
restore_net(constant_op.constant([[1.0]]))
|
||||
self.assertAllEqual([[42.]],
|
||||
self.evaluate(restore_net.variables[0]))
|
||||
self.evaluate(restore_net.variables[0].assign([[-1.]]))
|
||||
# Immediate restoration
|
||||
restore_net.restore(save_path)
|
||||
network.restore_network_checkpoint(restore_net, save_path)
|
||||
self.assertAllEqual([[42.]],
|
||||
self.evaluate(restore_net.variables[0]))
|
||||
|
||||
@ -523,7 +609,7 @@ class NetworkTest(test.TestCase):
|
||||
one = constant_op.constant([[1.]])
|
||||
net = ParentNetwork()
|
||||
net(one)
|
||||
self.assertStartsWith(expected_start="parent_network_1/explicit_name/",
|
||||
self.assertStartsWith(expected_start="parent_network/explicit_name/",
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertEqual("explicit_name", net.first.name)
|
||||
|
||||
@ -578,15 +664,15 @@ class NetworkTest(test.TestCase):
|
||||
# locally so that previous Layer construction does not interfere with
|
||||
# variable naming (e.g. add a Layer construction before the Network,
|
||||
# suddenly your previously saved checkpoint is incompatible).
|
||||
self.assertEqual("dense_1", net1.l1.name)
|
||||
self.assertEqual("dense_1", net2.l1.name)
|
||||
self.assertEqual("dense", net1.l1.name)
|
||||
self.assertEqual("dense", net2.l1.name)
|
||||
self.evaluate(net1.trainable_weights[0].assign([[1.]]))
|
||||
self.evaluate(net2.trainable_weights[0].assign([[2.]]))
|
||||
self.assertEqual(2., self.evaluate(net2.trainable_weights[0]))
|
||||
self.assertEqual(1., self.evaluate(net1.trainable_weights[0]))
|
||||
self.assertStartsWith(expected_start="my_network_1/dense_1/",
|
||||
self.assertStartsWith(expected_start="my_network/dense/",
|
||||
actual=net1.trainable_weights[0].name)
|
||||
self.assertStartsWith(expected_start="my_network_2/dense_1/",
|
||||
self.assertStartsWith(expected_start="my_network_1/dense/",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@ -607,31 +693,31 @@ class NetworkTest(test.TestCase):
|
||||
one = constant_op.constant([[1.]])
|
||||
net = ParentNetwork()
|
||||
net(one)
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
|
||||
self.assertStartsWith(expected_start="parent_network/my_network/dense",
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
|
||||
self.assertStartsWith(expected_start="parent_network/my_network/dense",
|
||||
actual=net.first.trainable_weights[0].name)
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense",
|
||||
self.assertStartsWith(expected_start="parent_network/my_network_1/dense",
|
||||
actual=net.trainable_weights[1].name)
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense",
|
||||
self.assertStartsWith(expected_start="parent_network/my_network_1/dense",
|
||||
actual=net.second.trainable_weights[0].name)
|
||||
self.assertEqual("parent_network_1", net.name)
|
||||
self.assertEqual("my_network_1", net.first.name)
|
||||
self.assertEqual("my_network_2", net.second.name)
|
||||
self.assertEqual("parent_network", net.name)
|
||||
self.assertEqual("my_network", net.first.name)
|
||||
self.assertEqual("my_network_1", net.second.name)
|
||||
|
||||
net2 = ParentNetwork()
|
||||
net2(one)
|
||||
self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense",
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network/dense",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense",
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network/dense",
|
||||
actual=net2.first.trainable_weights[0].name)
|
||||
self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense",
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
|
||||
actual=net2.trainable_weights[1].name)
|
||||
self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense",
|
||||
self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense",
|
||||
actual=net2.second.trainable_weights[0].name)
|
||||
self.assertEqual("parent_network_2", net2.name)
|
||||
self.assertEqual("my_network_1", net2.first.name)
|
||||
self.assertEqual("my_network_2", net2.second.name)
|
||||
self.assertEqual("parent_network_1", net2.name)
|
||||
self.assertEqual("my_network", net2.first.name)
|
||||
self.assertEqual("my_network_1", net2.second.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNestableExplicit(self):
|
||||
@ -692,26 +778,26 @@ class NetworkTest(test.TestCase):
|
||||
one = constant_op.constant([[1.]])
|
||||
net = MixedLayerNetwork()
|
||||
net(one)
|
||||
self.assertEqual("dense_1", net.first.name)
|
||||
self.assertEqual("dense_2", net.second.name)
|
||||
self.assertEqual("dense_3", net.third.name)
|
||||
self.assertEqual("dense_4", net.fourth.name)
|
||||
self.assertEqual("dense_5", net.fifth.name)
|
||||
self.assertEqual("dense", net.first.name)
|
||||
self.assertEqual("dense_1", net.second.name)
|
||||
self.assertEqual("dense_2", net.third.name)
|
||||
self.assertEqual("dense_3", net.fourth.name)
|
||||
self.assertEqual("dense_4", net.fifth.name)
|
||||
# Note that this is _not_ the default naming behavior for Layers. Layers
|
||||
# which are added to Networks follow Network variable naming conventions
|
||||
# (i.e. variable names = network name unless variable sharing). Nested
|
||||
# Layers revert to Layer behavior.
|
||||
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/",
|
||||
self.assertStartsWith(expected_start="mixed_layer_network/dense/",
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/",
|
||||
self.assertStartsWith(expected_start="mixed_layer_network/dense_1/",
|
||||
actual=net.trainable_weights[1].name)
|
||||
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/",
|
||||
self.assertStartsWith(expected_start="mixed_layer_network/dense_2/",
|
||||
actual=net.trainable_weights[2].name)
|
||||
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/",
|
||||
self.assertStartsWith(expected_start="mixed_layer_network/dense_3/",
|
||||
actual=net.trainable_weights[3].name)
|
||||
self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/",
|
||||
self.assertStartsWith(expected_start="mixed_layer_network/dense_4/",
|
||||
actual=net.trainable_weights[4].name)
|
||||
self.assertEqual("mixed_layer_network_1", net.name)
|
||||
self.assertEqual("mixed_layer_network", net.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNestableExplicitCollisions(self):
|
||||
@ -764,24 +850,24 @@ class NetworkTest(test.TestCase):
|
||||
net = ParentNetwork()
|
||||
net(one)
|
||||
self.assertStartsWith(
|
||||
expected_start="parent_network_1/first_unique_child_name/dense_1/",
|
||||
expected_start="parent_network/first_unique_child_name/dense/",
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="parent_network_1/second_unique_child_name/dense_1/",
|
||||
expected_start="parent_network/second_unique_child_name/dense/",
|
||||
actual=net.trainable_weights[1].name)
|
||||
self.assertEqual("parent_network_1", net.name)
|
||||
self.assertEqual("parent_network", net.name)
|
||||
self.assertEqual("first_unique_child_name", net.first.name)
|
||||
self.assertEqual("second_unique_child_name", net.second.name)
|
||||
|
||||
net2 = ParentNetwork()
|
||||
net2(one)
|
||||
self.assertStartsWith(
|
||||
expected_start="parent_network_2/first_unique_child_name/dense",
|
||||
expected_start="parent_network_1/first_unique_child_name/dense",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="parent_network_2/second_unique_child_name/dense",
|
||||
expected_start="parent_network_1/second_unique_child_name/dense",
|
||||
actual=net2.trainable_weights[1].name)
|
||||
self.assertEqual("parent_network_2", net2.name)
|
||||
self.assertEqual("parent_network_1", net2.name)
|
||||
self.assertEqual("first_unique_child_name", net2.first.name)
|
||||
self.assertEqual("second_unique_child_name", net2.second.name)
|
||||
|
||||
@ -839,15 +925,15 @@ class NetworkTest(test.TestCase):
|
||||
net2(one)
|
||||
|
||||
self.assertStartsWith(
|
||||
expected_start="first_parent_network_1/my_network_1/dense_1/",
|
||||
expected_start="first_parent_network/my_network/dense/",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="second_parent_network_1/my_network_1/dense_1/",
|
||||
expected_start="second_parent_network/my_network/dense/",
|
||||
actual=net2.trainable_weights[1].name)
|
||||
self.assertEqual("second_parent_network_1", net2.name)
|
||||
self.assertEqual("second_parent_network", net2.name)
|
||||
self.assertTrue(net2.first is net.first)
|
||||
self.assertEqual("my_network_1", net2.first.name)
|
||||
self.assertEqual("my_network_1", net2.second.name)
|
||||
self.assertEqual("my_network", net2.first.name)
|
||||
self.assertEqual("my_network", net2.second.name)
|
||||
|
||||
# No name collision; the owned Network is added first and has a different
|
||||
# name than the shared Network.
|
||||
@ -865,15 +951,15 @@ class NetworkTest(test.TestCase):
|
||||
net3(one)
|
||||
|
||||
self.assertStartsWith(
|
||||
expected_start="third_parent_network_1/my_network_1/dense",
|
||||
expected_start="third_parent_network/my_network/dense",
|
||||
actual=net3.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_parent_network_1/my_network_2/dense",
|
||||
expected_start="first_parent_network/my_network_1/dense",
|
||||
actual=net3.trainable_weights[1].name)
|
||||
self.assertEqual("third_parent_network_1", net3.name)
|
||||
self.assertEqual("third_parent_network", net3.name)
|
||||
self.assertTrue(net3.second is net.second)
|
||||
self.assertEqual("my_network_1", net3.first.name)
|
||||
self.assertEqual("my_network_2", net3.second.name)
|
||||
self.assertEqual("my_network", net3.first.name)
|
||||
self.assertEqual("my_network_1", net3.second.name)
|
||||
|
||||
# "Unavoidable" same-name Layer. The owned name is added first (fixed), then
|
||||
# a shared Network is added with the same name.
|
||||
@ -891,15 +977,15 @@ class NetworkTest(test.TestCase):
|
||||
net4(one)
|
||||
|
||||
self.assertStartsWith(
|
||||
expected_start="fourth_parent_network_1/my_network_1/dense_1/",
|
||||
expected_start="fourth_parent_network/my_network/dense/",
|
||||
actual=net4.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_parent_network_1/my_network_1/dense_1/",
|
||||
expected_start="first_parent_network/my_network/dense/",
|
||||
actual=net4.trainable_weights[1].name)
|
||||
self.assertEqual("fourth_parent_network_1", net4.name)
|
||||
self.assertEqual("fourth_parent_network", net4.name)
|
||||
self.assertTrue(net4.second is net.first)
|
||||
self.assertEqual("my_network_1", net4.first.name)
|
||||
self.assertEqual("my_network_1", net4.second.name)
|
||||
self.assertEqual("my_network", net4.first.name)
|
||||
self.assertEqual("my_network", net4.second.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testRecursiveLayerRenaming(self):
|
||||
@ -930,28 +1016,28 @@ class NetworkTest(test.TestCase):
|
||||
net(one)
|
||||
|
||||
self.assertStartsWith(
|
||||
expected_start=("parent_network_1/network_with_layer_children_1/"
|
||||
"dense_1/"),
|
||||
expected_start=("parent_network/network_with_layer_children/"
|
||||
"dense/"),
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start=("parent_network_1/network_with_layer_children_1/"
|
||||
"dense_2/"),
|
||||
expected_start=("parent_network/network_with_layer_children/"
|
||||
"dense_1/"),
|
||||
actual=net.trainable_weights[1].name)
|
||||
self.assertStartsWith(
|
||||
expected_start=("parent_network_1/network_with_layer_children_2/"
|
||||
"dense_1/"),
|
||||
expected_start=("parent_network/network_with_layer_children_1/"
|
||||
"dense/"),
|
||||
actual=net.trainable_weights[2].name)
|
||||
self.assertStartsWith(
|
||||
expected_start=("parent_network_1/network_with_layer_children_2/"
|
||||
"dense_2/"),
|
||||
expected_start=("parent_network/network_with_layer_children_1/"
|
||||
"dense_1/"),
|
||||
actual=net.trainable_weights[3].name)
|
||||
self.assertEqual("parent_network_1", net.name)
|
||||
self.assertEqual("network_with_layer_children_1", net.first.name)
|
||||
self.assertEqual("network_with_layer_children_2", net.second.name)
|
||||
self.assertEqual("dense_1", net.first.first.name)
|
||||
self.assertEqual("dense_2", net.first.second.name)
|
||||
self.assertEqual("dense_1", net.second.first.name)
|
||||
self.assertEqual("dense_2", net.second.second.name)
|
||||
self.assertEqual("parent_network", net.name)
|
||||
self.assertEqual("network_with_layer_children", net.first.name)
|
||||
self.assertEqual("network_with_layer_children_1", net.second.name)
|
||||
self.assertEqual("dense", net.first.first.name)
|
||||
self.assertEqual("dense_1", net.first.second.name)
|
||||
self.assertEqual("dense", net.second.first.name)
|
||||
self.assertEqual("dense_1", net.second.second.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testCallInDifferentOrderThanConstruct(self):
|
||||
@ -985,23 +1071,23 @@ class NetworkTest(test.TestCase):
|
||||
net1(one)
|
||||
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/my_network_1/dense_1/",
|
||||
expected_start="first_network/my_network/dense/",
|
||||
actual=net1.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/my_network_2/dense_1/",
|
||||
expected_start="first_network/my_network_1/dense/",
|
||||
actual=net1.trainable_weights[1].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/my_network_1/dense_1/",
|
||||
expected_start="first_network/my_network/dense/",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="second_network_1/my_network_1/dense_1/",
|
||||
expected_start="second_network/my_network/dense/",
|
||||
actual=net2.trainable_weights[1].name)
|
||||
self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0])
|
||||
self.assertEqual("first_network_1", net1.name)
|
||||
self.assertEqual("my_network_1", net1.first.name)
|
||||
self.assertEqual("my_network_2", net1.second.name)
|
||||
self.assertEqual("first_network", net1.name)
|
||||
self.assertEqual("my_network", net1.first.name)
|
||||
self.assertEqual("my_network_1", net1.second.name)
|
||||
self.assertTrue(net2.first is net1.first)
|
||||
self.assertEqual("my_network_1", net2.second.name)
|
||||
self.assertEqual("my_network", net2.second.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testLayerCallInDifferentOrderThanConstruct(self):
|
||||
@ -1038,23 +1124,23 @@ class NetworkTest(test.TestCase):
|
||||
net1(one)
|
||||
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/dense_1/",
|
||||
expected_start="first_network/dense/",
|
||||
actual=net1.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/dense_2/",
|
||||
expected_start="first_network/dense_1/",
|
||||
actual=net1.trainable_weights[1].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/dense_1/",
|
||||
expected_start="first_network/dense/",
|
||||
actual=net2.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="second_network_1/dense_1/",
|
||||
expected_start="second_network/dense/",
|
||||
actual=net2.trainable_weights[1].name)
|
||||
self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0])
|
||||
self.assertEqual("first_network_1", net1.name)
|
||||
self.assertEqual("dense_1", net1.first.name)
|
||||
self.assertEqual("dense_2", net1.second.name)
|
||||
self.assertEqual("first_network", net1.name)
|
||||
self.assertEqual("dense", net1.first.name)
|
||||
self.assertEqual("dense_1", net1.second.name)
|
||||
self.assertTrue(net2.first is net1.first)
|
||||
self.assertEqual("dense_1", net2.second.name)
|
||||
self.assertEqual("dense", net2.second.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testLayerAlreadyBuilt(self):
|
||||
@ -1083,13 +1169,13 @@ class NetworkTest(test.TestCase):
|
||||
# do not match their layer names.
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertStartsWith(
|
||||
expected_start="first_network_1/dense_1/",
|
||||
expected_start="first_network/dense/",
|
||||
actual=net.trainable_weights[1].name)
|
||||
self.assertTrue(
|
||||
net.trainable_weights[0] is shared_layer.trainable_weights[0])
|
||||
self.assertEqual("first_network_1", net.name)
|
||||
self.assertEqual("first_network", net.name)
|
||||
self.assertEqual("dense_3", net.first.name)
|
||||
self.assertEqual("dense_1", net.second.name)
|
||||
self.assertEqual("dense", net.second.name)
|
||||
|
||||
|
||||
class SequentialTest(test.TestCase):
|
||||
|
@ -30,9 +30,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||
@@value_and_gradients_function
|
||||
@@GradientTape
|
||||
|
||||
@@enable_tracing
|
||||
@@flush_trace
|
||||
|
||||
@@run
|
||||
@@enable_eager_execution
|
||||
|
||||
@ -46,13 +43,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||
@@seterr
|
||||
|
||||
@@Iterator
|
||||
@@Network
|
||||
@@Saver
|
||||
@@restore_variables_on_create
|
||||
@@Variable
|
||||
@@get_optimizer_variables
|
||||
@@EagerVariableStore
|
||||
|
||||
@@Network
|
||||
@@save_network_checkpoint
|
||||
@@restore_network_checkpoint
|
||||
|
||||
@@in_eager_mode
|
||||
@@in_graph_mode
|
||||
|
||||
@ -74,6 +74,8 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.eager.python import metrics
|
||||
from tensorflow.contrib.eager.python.datasets import Iterator
|
||||
from tensorflow.contrib.eager.python.network import Network
|
||||
from tensorflow.contrib.eager.python.network import save_network_checkpoint
|
||||
from tensorflow.contrib.eager.python.network import restore_network_checkpoint
|
||||
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
|
||||
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
|
||||
from tensorflow.contrib.eager.python.saver import Saver
|
||||
@ -86,7 +88,6 @@ from tensorflow.python.eager.context import in_eager_mode
|
||||
from tensorflow.python.eager.context import in_graph_mode
|
||||
from tensorflow.python.eager.context import list_devices
|
||||
from tensorflow.python.eager.context import num_gpus
|
||||
from tensorflow.python.eager.core import enable_tracing
|
||||
from tensorflow.python.eager.custom_gradient import custom_gradient
|
||||
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
||||
from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
|
||||
|
@ -208,6 +208,7 @@ py_library(
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:metrics",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python/estimator:head",
|
||||
"//tensorflow/python/estimator:metric_keys",
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics as metrics_lib
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.summary import summary
|
||||
|
||||
@ -342,14 +343,19 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
|
||||
predictions = {}
|
||||
metrics = {}
|
||||
losses = []
|
||||
for head, spec in zip(self._heads, all_estimator_spec):
|
||||
losses.append(spec.loss)
|
||||
head_name = head.name
|
||||
# Metric keys already contain head.name.
|
||||
metrics.update(spec.eval_metric_ops or {})
|
||||
for k, v in six.iteritems(spec.predictions):
|
||||
predictions[(head_name, k)] = v
|
||||
loss = _merge_losses(losses, self._head_weights)
|
||||
with ops.name_scope('merge_eval'):
|
||||
for head, spec in zip(self._heads, all_estimator_spec):
|
||||
losses.append(spec.loss)
|
||||
head_name = head.name
|
||||
# Loss metric is not added by default.
|
||||
loss_name = head_lib._summary_key( # pylint:disable=protected-access
|
||||
head_name, metric_keys.MetricKeys.LOSS)
|
||||
metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name)
|
||||
# Metric keys already contain head.name.
|
||||
metrics.update(spec.eval_metric_ops or {})
|
||||
for k, v in six.iteritems(spec.predictions):
|
||||
predictions[(head_name, k)] = v
|
||||
loss = _merge_losses(losses, self._head_weights)
|
||||
|
||||
return model_fn.EstimatorSpec(
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
|
@ -297,6 +297,8 @@ class MultiHeadTest(test.TestCase):
|
||||
|
||||
keys = metric_keys.MetricKeys
|
||||
expected_metrics = {
|
||||
keys.LOSS + '/head1': expected_loss_head1,
|
||||
keys.LOSS + '/head2': expected_loss_head2,
|
||||
# Average loss over examples.
|
||||
keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
|
||||
keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.factorization.python.ops import factorization_ops
|
||||
from tensorflow.contrib.framework.python.ops import variables as framework_variables
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -32,175 +31,64 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
class _SweepHook(session_run_hook.SessionRunHook):
|
||||
"""Keeps track of row/col sweeps, and runs prep ops before each sweep."""
|
||||
|
||||
def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols,
|
||||
input_row_indices, input_col_indices, row_prep_ops,
|
||||
col_prep_ops, init_op, completed_sweeps_var):
|
||||
def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op,
|
||||
row_prep_ops, col_prep_ops, row_train_op, col_train_op,
|
||||
switch_op):
|
||||
"""Initializes SweepHook.
|
||||
|
||||
Args:
|
||||
is_row_sweep_var: A Boolean tf.Variable, determines whether we are
|
||||
currently doing a row or column sweep. It is updated by the hook.
|
||||
train_ops: A list of ops. The ops created by this hook will have
|
||||
control dependencies on `train_ops`.
|
||||
num_rows: int, the total number of rows to be processed.
|
||||
num_cols: int, the total number of columns to be processed.
|
||||
input_row_indices: A Tensor of type int64. The indices of the input rows
|
||||
that are processed during the current sweep. All elements of
|
||||
`input_row_indices` must be in [0, num_rows).
|
||||
input_col_indices: A Tensor of type int64. The indices of the input
|
||||
columns that are processed during the current sweep. All elements of
|
||||
`input_col_indices` must be in [0, num_cols).
|
||||
row_prep_ops: list of ops, to be run before the beginning of each row
|
||||
sweep, in the given order.
|
||||
col_prep_ops: list of ops, to be run before the beginning of each column
|
||||
sweep, in the given order.
|
||||
is_sweep_done_var: A Boolean tf.Variable, determines whether we are
|
||||
starting a new sweep (this is used to determine when to run the prep ops
|
||||
below).
|
||||
init_op: op to be run once before training. This is typically a local
|
||||
initialization op (such as cache initialization).
|
||||
completed_sweeps_var: An integer tf.Variable, indicates the number of
|
||||
completed sweeps. It is updated by the hook.
|
||||
row_prep_ops: A list of TensorFlow ops, to be run before the beginning of
|
||||
each row sweep (and during initialization), in the given order.
|
||||
col_prep_ops: A list of TensorFlow ops, to be run before the beginning of
|
||||
each column sweep (and during initialization), in the given order.
|
||||
row_train_op: A TensorFlow op to be run during row sweeps.
|
||||
col_train_op: A TensorFlow op to be run during column sweeps.
|
||||
switch_op: A TensorFlow op to be run before each sweep.
|
||||
"""
|
||||
self._num_rows = num_rows
|
||||
self._num_cols = num_cols
|
||||
self._is_row_sweep_var = is_row_sweep_var
|
||||
self._is_sweep_done_var = is_sweep_done_var
|
||||
self._init_op = init_op
|
||||
self._row_prep_ops = row_prep_ops
|
||||
self._col_prep_ops = col_prep_ops
|
||||
self._init_op = init_op
|
||||
self._is_row_sweep_var = is_row_sweep_var
|
||||
self._completed_sweeps_var = completed_sweeps_var
|
||||
# Boolean variable that determines whether the init_ops have been run.
|
||||
self._row_train_op = row_train_op
|
||||
self._col_train_op = col_train_op
|
||||
self._switch_op = switch_op
|
||||
# Boolean variable that determines whether the init_op has been run.
|
||||
self._is_initialized = False
|
||||
# Ops to run jointly with train_ops, responsible for updating
|
||||
# `is_row_sweep_var` and incrementing the `global_step` and
|
||||
# `completed_sweeps` counters.
|
||||
self._update_op, self._is_sweep_done_var, self._switch_op = (
|
||||
self._create_hook_ops(input_row_indices, input_col_indices, train_ops))
|
||||
|
||||
def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops):
|
||||
"""Creates ops to update is_row_sweep_var, global_step and completed_sweeps.
|
||||
|
||||
Creates two boolean tensors `processed_rows` and `processed_cols`, which
|
||||
keep track of which rows/cols have been processed during the current sweep.
|
||||
Returns ops that should be run after each row / col update.
|
||||
- When `self._is_row_sweep_var` is True, it sets
|
||||
processed_rows[input_row_indices] to True.
|
||||
- When `self._is_row_sweep_var` is False, it sets
|
||||
processed_cols[input_col_indices] to True.
|
||||
|
||||
Args:
|
||||
input_row_indices: A Tensor. The indices of the input rows that are
|
||||
processed during the current sweep.
|
||||
input_col_indices: A Tensor. The indices of the input columns that
|
||||
are processed during the current sweep.
|
||||
train_ops: A list of ops. The ops created by this function have control
|
||||
dependencies on `train_ops`.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of:
|
||||
update_op: An op to be run jointly with training. It updates the state
|
||||
and increments counters (global step and completed sweeps).
|
||||
is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is
|
||||
done, i.e. all rows (during a row sweep) or all columns (during a
|
||||
column sweep) have been processed.
|
||||
switch_op: An op to be run in `self.before_run` when the sweep is done.
|
||||
"""
|
||||
processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
|
||||
with ops.colocate_with(processed_rows_init):
|
||||
processed_rows = variable_scope.variable(
|
||||
processed_rows_init,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
|
||||
trainable=False,
|
||||
name="sweep_hook_processed_rows")
|
||||
processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False)
|
||||
with ops.colocate_with(processed_cols_init):
|
||||
processed_cols = variable_scope.variable(
|
||||
processed_cols_init,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
|
||||
trainable=False,
|
||||
name="sweep_hook_processed_cols")
|
||||
switch_ops = control_flow_ops.group(
|
||||
state_ops.assign(
|
||||
self._is_row_sweep_var,
|
||||
math_ops.logical_not(self._is_row_sweep_var)),
|
||||
state_ops.assign(processed_rows, processed_rows_init),
|
||||
state_ops.assign(processed_cols, processed_cols_init))
|
||||
is_sweep_done_var = variable_scope.variable(
|
||||
False,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
|
||||
trainable=False,
|
||||
name="is_sweep_done")
|
||||
|
||||
# After running the `train_ops`, updates `processed_rows` or
|
||||
# `processed_cols` tensors, depending on whether this is a row or col sweep.
|
||||
with ops.control_dependencies(train_ops):
|
||||
with ops.colocate_with(processed_rows):
|
||||
update_processed_rows = state_ops.scatter_update(
|
||||
processed_rows,
|
||||
input_row_indices,
|
||||
math_ops.logical_and(
|
||||
self._is_row_sweep_var,
|
||||
array_ops.ones_like(input_row_indices, dtype=dtypes.bool)))
|
||||
with ops.colocate_with(processed_cols):
|
||||
update_processed_cols = state_ops.scatter_update(
|
||||
processed_cols,
|
||||
input_col_indices,
|
||||
math_ops.logical_and(
|
||||
math_ops.logical_not(self._is_row_sweep_var),
|
||||
array_ops.ones_like(input_col_indices, dtype=dtypes.bool)))
|
||||
update_processed_op = control_flow_ops.group(
|
||||
update_processed_rows, update_processed_cols)
|
||||
|
||||
with ops.control_dependencies([update_processed_op]):
|
||||
is_sweep_done = math_ops.logical_or(
|
||||
math_ops.reduce_all(processed_rows),
|
||||
math_ops.reduce_all(processed_cols))
|
||||
# Increments global step.
|
||||
global_step = framework_variables.get_global_step()
|
||||
if global_step is not None:
|
||||
global_step_incr_op = state_ops.assign_add(
|
||||
global_step, 1, name="global_step_incr").op
|
||||
else:
|
||||
global_step_incr_op = control_flow_ops.no_op()
|
||||
# Increments completed sweeps.
|
||||
completed_sweeps_incr_op = state_ops.assign_add(
|
||||
self._completed_sweeps_var,
|
||||
math_ops.cast(is_sweep_done, dtypes.int32),
|
||||
use_locking=True).op
|
||||
update_ops = control_flow_ops.group(
|
||||
global_step_incr_op,
|
||||
completed_sweeps_incr_op,
|
||||
state_ops.assign(is_sweep_done_var, is_sweep_done))
|
||||
|
||||
return update_ops, is_sweep_done_var, switch_ops
|
||||
|
||||
def before_run(self, run_context):
|
||||
"""Runs the appropriate prep ops, and requests running update ops."""
|
||||
# Runs the appropriate init ops and prep ops.
|
||||
sess = run_context.session
|
||||
is_sweep_done = sess.run(self._is_sweep_done_var)
|
||||
if not self._is_initialized:
|
||||
logging.info("SweepHook running cache init op.")
|
||||
logging.info("SweepHook running init op.")
|
||||
sess.run(self._init_op)
|
||||
if is_sweep_done:
|
||||
sess.run(self._switch_op)
|
||||
is_row_sweep = sess.run(self._is_row_sweep_var)
|
||||
if is_sweep_done or not self._is_initialized:
|
||||
logging.info("SweepHook running sweep prep ops.")
|
||||
row_sweep = sess.run(self._is_row_sweep_var)
|
||||
prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops
|
||||
logging.info("SweepHook running prep ops for the {} sweep.".format(
|
||||
"row" if is_row_sweep else "col"))
|
||||
prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops
|
||||
for prep_op in prep_ops:
|
||||
sess.run(prep_op)
|
||||
|
||||
self._is_initialized = True
|
||||
|
||||
# Requests running `self._update_op` jointly with the training op.
|
||||
logging.info("Next fit step starting.")
|
||||
return session_run_hook.SessionRunArgs(fetches=[self._update_op])
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
logging.info("Fit step done.")
|
||||
return session_run_hook.SessionRunArgs(
|
||||
fetches=[self._row_train_op if is_row_sweep else self._col_train_op])
|
||||
|
||||
|
||||
class _StopAtSweepHook(session_run_hook.SessionRunHook):
|
||||
@ -246,6 +134,9 @@ def _wals_factorization_model_function(features, labels, mode, params):
|
||||
|
||||
Returns:
|
||||
A ModelFnOps object.
|
||||
|
||||
Raises:
|
||||
ValueError: If `mode` is not recognized.
|
||||
"""
|
||||
assert labels is None
|
||||
use_factors_weights_cache = (params["use_factors_weights_cache_for_training"]
|
||||
@ -269,86 +160,156 @@ def _wals_factorization_model_function(features, labels, mode, params):
|
||||
use_gramian_cache=use_gramian_cache)
|
||||
|
||||
# Get input rows and cols. We either update rows or columns depending on
|
||||
# the value of row_sweep, which is maintained using a session hook
|
||||
# the value of row_sweep, which is maintained using a session hook.
|
||||
input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
|
||||
input_cols = features[WALSMatrixFactorization.INPUT_COLS]
|
||||
input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0])
|
||||
input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0])
|
||||
|
||||
# Train ops, controlled using the SweepHook
|
||||
# We need to run the following ops:
|
||||
# Before a row sweep:
|
||||
# row_update_prep_gramian_op
|
||||
# initialize_row_update_op
|
||||
# During a row sweep:
|
||||
# update_row_factors_op
|
||||
# Before a col sweep:
|
||||
# col_update_prep_gramian_op
|
||||
# initialize_col_update_op
|
||||
# During a col sweep:
|
||||
# update_col_factors_op
|
||||
# TRAIN mode:
|
||||
if mode == model_fn.ModeKeys.TRAIN:
|
||||
# Training consists of the following ops (controlled using a SweepHook).
|
||||
# Before a row sweep:
|
||||
# row_update_prep_gramian_op
|
||||
# initialize_row_update_op
|
||||
# During a row sweep:
|
||||
# update_row_factors_op
|
||||
# Before a col sweep:
|
||||
# col_update_prep_gramian_op
|
||||
# initialize_col_update_op
|
||||
# During a col sweep:
|
||||
# update_col_factors_op
|
||||
|
||||
is_row_sweep_var = variable_scope.variable(
|
||||
True,
|
||||
trainable=False,
|
||||
name="is_row_sweep",
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
completed_sweeps_var = variable_scope.variable(
|
||||
0,
|
||||
trainable=False,
|
||||
name=WALSMatrixFactorization.COMPLETED_SWEEPS,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
is_row_sweep_var = variable_scope.variable(
|
||||
True,
|
||||
trainable=False,
|
||||
name="is_row_sweep",
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
is_sweep_done_var = variable_scope.variable(
|
||||
False,
|
||||
trainable=False,
|
||||
name="is_sweep_done",
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
completed_sweeps_var = variable_scope.variable(
|
||||
0,
|
||||
trainable=False,
|
||||
name=WALSMatrixFactorization.COMPLETED_SWEEPS,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
loss_var = variable_scope.variable(
|
||||
0.,
|
||||
trainable=False,
|
||||
name=WALSMatrixFactorization.LOSS,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
# The root weighted squared error =
|
||||
# \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
|
||||
rwse_var = variable_scope.variable(
|
||||
0.,
|
||||
trainable=False,
|
||||
name=WALSMatrixFactorization.RWSE,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
|
||||
|
||||
# The row sweep is determined by is_row_sweep_var (controlled by the
|
||||
# sweep_hook) in TRAIN mode, and manually in EVAL mode.
|
||||
is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
|
||||
if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var)
|
||||
summary.scalar("loss", loss_var)
|
||||
summary.scalar("root_weighted_squared_error", rwse_var)
|
||||
summary.scalar("completed_sweeps", completed_sweeps_var)
|
||||
|
||||
def update_row_factors():
|
||||
return model.update_row_factors(sp_input=input_rows, transpose_input=False)
|
||||
# Increments global step.
|
||||
global_step = training_util.get_global_step()
|
||||
if global_step:
|
||||
global_step_incr_op = state_ops.assign_add(
|
||||
global_step, 1, name="global_step_incr").op
|
||||
else:
|
||||
global_step_incr_op = control_flow_ops.no_op()
|
||||
|
||||
def update_col_factors():
|
||||
return model.update_col_factors(sp_input=input_cols, transpose_input=True)
|
||||
def create_axis_ops(sp_input, num_items, update_fn, axis_name):
|
||||
"""Creates book-keeping and training ops for a given axis.
|
||||
|
||||
(_, train_op,
|
||||
unregularized_loss, regularization, sum_weights) = control_flow_ops.cond(
|
||||
is_row_sweep, update_row_factors, update_col_factors)
|
||||
loss = unregularized_loss + regularization
|
||||
root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights)
|
||||
Args:
|
||||
sp_input: A SparseTensor corresponding to the row or column batch.
|
||||
num_items: An integer, the total number of items of this axis.
|
||||
update_fn: A function that takes one argument (`sp_input`), and that
|
||||
returns a tuple of
|
||||
* new_factors: A float Tensor of the factor values after update.
|
||||
* update_op: a TensorFlow op which updates the factors.
|
||||
* loss: A float Tensor, the unregularized loss.
|
||||
* reg_loss: A float Tensor, the regularization loss.
|
||||
* sum_weights: A float Tensor, the sum of factor weights.
|
||||
axis_name: A string that specifies the name of the axis.
|
||||
|
||||
row_prep_ops = [
|
||||
model.row_update_prep_gramian_op, model.initialize_row_update_op
|
||||
]
|
||||
col_prep_ops = [
|
||||
model.col_update_prep_gramian_op, model.initialize_col_update_op
|
||||
]
|
||||
init_ops = [model.worker_init]
|
||||
Returns:
|
||||
A tuple consisting of:
|
||||
* reset_processed_items_op: A TensorFlow op, to be run before the
|
||||
beginning of any sweep. It marks all items as not-processed.
|
||||
* axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
|
||||
"""
|
||||
processed_items_init = array_ops.fill(dims=[num_items], value=False)
|
||||
with ops.colocate_with(processed_items_init):
|
||||
processed_items = variable_scope.variable(
|
||||
processed_items_init,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
|
||||
trainable=False,
|
||||
name="processed_" + axis_name)
|
||||
reset_processed_items_op = state_ops.assign(
|
||||
processed_items, processed_items_init,
|
||||
name="reset_processed_" + axis_name)
|
||||
_, update_op, loss, reg, sum_weights = update_fn(sp_input)
|
||||
input_indices = sp_input.indices[:, 0]
|
||||
with ops.control_dependencies([
|
||||
update_op,
|
||||
state_ops.assign(loss_var, loss + reg),
|
||||
state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
|
||||
with ops.colocate_with(processed_items):
|
||||
update_processed_items = state_ops.scatter_update(
|
||||
processed_items,
|
||||
input_indices,
|
||||
array_ops.ones_like(input_indices, dtype=dtypes.bool),
|
||||
name="update_processed_{}_indices".format(axis_name))
|
||||
with ops.control_dependencies([update_processed_items]):
|
||||
is_sweep_done = math_ops.reduce_all(processed_items)
|
||||
axis_train_op = control_flow_ops.group(
|
||||
global_step_incr_op,
|
||||
state_ops.assign(is_sweep_done_var, is_sweep_done),
|
||||
state_ops.assign_add(
|
||||
completed_sweeps_var,
|
||||
math_ops.cast(is_sweep_done, dtypes.int32)),
|
||||
name="{}_sweep_train_op".format(axis_name))
|
||||
return reset_processed_items_op, axis_train_op
|
||||
|
||||
sweep_hook = _SweepHook(
|
||||
is_row_sweep_var,
|
||||
[train_op, loss],
|
||||
params["num_rows"],
|
||||
params["num_cols"],
|
||||
input_row_indices,
|
||||
input_col_indices,
|
||||
row_prep_ops,
|
||||
col_prep_ops,
|
||||
init_ops,
|
||||
completed_sweeps_var)
|
||||
training_hooks = [sweep_hook]
|
||||
if max_sweeps is not None:
|
||||
training_hooks.append(_StopAtSweepHook(max_sweeps))
|
||||
reset_processed_rows_op, row_train_op = create_axis_ops(
|
||||
input_rows,
|
||||
params["num_rows"],
|
||||
lambda x: model.update_row_factors(sp_input=x, transpose_input=False),
|
||||
"rows")
|
||||
reset_processed_cols_op, col_train_op = create_axis_ops(
|
||||
input_cols,
|
||||
params["num_cols"],
|
||||
lambda x: model.update_col_factors(sp_input=x, transpose_input=True),
|
||||
"cols")
|
||||
switch_op = control_flow_ops.group(
|
||||
state_ops.assign(
|
||||
is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)),
|
||||
reset_processed_rows_op,
|
||||
reset_processed_cols_op,
|
||||
name="sweep_switch_op")
|
||||
row_prep_ops = [
|
||||
model.row_update_prep_gramian_op, model.initialize_row_update_op]
|
||||
col_prep_ops = [
|
||||
model.col_update_prep_gramian_op, model.initialize_col_update_op]
|
||||
init_op = model.worker_init
|
||||
sweep_hook = _SweepHook(
|
||||
is_row_sweep_var, is_sweep_done_var, init_op,
|
||||
row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
|
||||
training_hooks = [sweep_hook]
|
||||
if max_sweeps is not None:
|
||||
training_hooks.append(_StopAtSweepHook(max_sweeps))
|
||||
|
||||
# The root weighted squared error =
|
||||
# \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
|
||||
summary.scalar("loss", loss) # the estimated total training loss
|
||||
summary.scalar("root_weighted_squared_error", root_weighted_squared_error)
|
||||
summary.scalar("completed_sweeps", completed_sweeps_var)
|
||||
return model_fn.ModelFnOps(
|
||||
mode=model_fn.ModeKeys.TRAIN,
|
||||
predictions={},
|
||||
loss=loss_var,
|
||||
eval_metric_ops={},
|
||||
train_op=control_flow_ops.no_op(),
|
||||
training_hooks=training_hooks)
|
||||
|
||||
# Prediction ops (only return predictions in INFER mode)
|
||||
predictions = {}
|
||||
if mode == model_fn.ModeKeys.INFER:
|
||||
project_row = features[WALSMatrixFactorization.PROJECT_ROW]
|
||||
# INFER mode
|
||||
elif mode == model_fn.ModeKeys.INFER:
|
||||
projection_weights = features.get(
|
||||
WALSMatrixFactorization.PROJECTION_WEIGHTS)
|
||||
|
||||
@ -364,17 +325,45 @@ def _wals_factorization_model_function(features, labels, mode, params):
|
||||
projection_weights=projection_weights,
|
||||
transpose_input=True)
|
||||
|
||||
predictions[WALSMatrixFactorization.PROJECTION_RESULT] = (
|
||||
control_flow_ops.cond(project_row, get_row_projection,
|
||||
get_col_projection))
|
||||
predictions = {
|
||||
WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond(
|
||||
features[WALSMatrixFactorization.PROJECT_ROW],
|
||||
get_row_projection,
|
||||
get_col_projection)
|
||||
}
|
||||
|
||||
return model_fn.ModelFnOps(
|
||||
mode=mode,
|
||||
predictions=predictions,
|
||||
loss=loss,
|
||||
eval_metric_ops={},
|
||||
train_op=train_op,
|
||||
training_hooks=training_hooks)
|
||||
return model_fn.ModelFnOps(
|
||||
mode=model_fn.ModeKeys.INFER,
|
||||
predictions=predictions,
|
||||
loss=None,
|
||||
eval_metric_ops={},
|
||||
train_op=control_flow_ops.no_op(),
|
||||
training_hooks=[])
|
||||
|
||||
# EVAL mode
|
||||
elif mode == model_fn.ModeKeys.EVAL:
|
||||
def get_row_loss():
|
||||
_, _, loss, reg, _ = model.update_row_factors(
|
||||
sp_input=input_rows, transpose_input=False)
|
||||
return loss + reg
|
||||
def get_col_loss():
|
||||
_, _, loss, reg, _ = model.update_col_factors(
|
||||
sp_input=input_cols, transpose_input=True)
|
||||
return loss + reg
|
||||
loss = control_flow_ops.cond(
|
||||
features[WALSMatrixFactorization.PROJECT_ROW],
|
||||
get_row_loss,
|
||||
get_col_loss)
|
||||
return model_fn.ModelFnOps(
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
predictions={},
|
||||
loss=loss,
|
||||
eval_metric_ops={},
|
||||
train_op=control_flow_ops.no_op(),
|
||||
training_hooks=[])
|
||||
|
||||
else:
|
||||
raise ValueError("mode=%s is not recognized." % str(mode))
|
||||
|
||||
|
||||
class WALSMatrixFactorization(estimator.Estimator):
|
||||
@ -452,6 +441,10 @@ class WALSMatrixFactorization(estimator.Estimator):
|
||||
PROJECTION_RESULT = "projection"
|
||||
# Name of the completed_sweeps variable
|
||||
COMPLETED_SWEEPS = "completed_sweeps"
|
||||
# Name of the loss variable
|
||||
LOSS = "WALS_loss"
|
||||
# Name of the Root Weighted Squared Error variable
|
||||
RWSE = "WALS_RWSE"
|
||||
|
||||
def __init__(self,
|
||||
num_rows,
|
||||
|
@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
|
||||
|
||||
class SweepHookTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._num_rows = 5
|
||||
self._num_cols = 7
|
||||
self._train_op = control_flow_ops.no_op()
|
||||
self._row_prep_done = variables.Variable(False)
|
||||
self._col_prep_done = variables.Variable(False)
|
||||
self._init_done = variables.Variable(False)
|
||||
self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)]
|
||||
self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)]
|
||||
self._init_ops = [state_ops.assign(self._init_done, True)]
|
||||
self._input_row_indices_ph = array_ops.placeholder(dtypes.int64)
|
||||
self._input_col_indices_ph = array_ops.placeholder(dtypes.int64)
|
||||
|
||||
def test_sweeps(self):
|
||||
def ind_feed(row_indices, col_indices):
|
||||
return {
|
||||
self._input_row_indices_ph: row_indices,
|
||||
self._input_col_indices_ph: col_indices
|
||||
}
|
||||
is_row_sweep_var = variables.Variable(True)
|
||||
is_sweep_done_var = variables.Variable(False)
|
||||
init_done = variables.Variable(False)
|
||||
row_prep_done = variables.Variable(False)
|
||||
col_prep_done = variables.Variable(False)
|
||||
row_train_done = variables.Variable(False)
|
||||
col_train_done = variables.Variable(False)
|
||||
|
||||
init_op = state_ops.assign(init_done, True)
|
||||
row_prep_op = state_ops.assign(row_prep_done, True)
|
||||
col_prep_op = state_ops.assign(col_prep_done, True)
|
||||
row_train_op = state_ops.assign(row_train_done, True)
|
||||
col_train_op = state_ops.assign(col_train_done, True)
|
||||
train_op = control_flow_ops.no_op()
|
||||
switch_op = control_flow_ops.group(
|
||||
state_ops.assign(is_sweep_done_var, False),
|
||||
state_ops.assign(is_row_sweep_var,
|
||||
math_ops.logical_not(is_row_sweep_var)))
|
||||
mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
|
||||
|
||||
with self.test_session() as sess:
|
||||
is_row_sweep_var = variables.Variable(True)
|
||||
completed_sweeps_var = variables.Variable(0)
|
||||
sweep_hook = wals_lib._SweepHook(
|
||||
is_row_sweep_var,
|
||||
[self._train_op],
|
||||
self._num_rows,
|
||||
self._num_cols,
|
||||
self._input_row_indices_ph,
|
||||
self._input_col_indices_ph,
|
||||
self._row_prep_ops,
|
||||
self._col_prep_ops,
|
||||
self._init_ops,
|
||||
completed_sweeps_var)
|
||||
is_sweep_done_var,
|
||||
init_op,
|
||||
[row_prep_op],
|
||||
[col_prep_op],
|
||||
row_train_op,
|
||||
col_train_op,
|
||||
switch_op)
|
||||
mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
|
||||
# Init ops should run before the first run. Row sweep not completed.
|
||||
mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
|
||||
self.assertTrue(sess.run(self._init_done),
|
||||
msg='init ops not run by the sweep_hook')
|
||||
self.assertTrue(sess.run(self._row_prep_done),
|
||||
msg='row_prep not run by the sweep_hook')
|
||||
self.assertTrue(sess.run(is_row_sweep_var),
|
||||
msg='Row sweep is not complete but is_row_sweep is '
|
||||
'False.')
|
||||
# Row sweep completed.
|
||||
mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
|
||||
self.assertTrue(sess.run(completed_sweeps_var) == 1,
|
||||
msg='Completed sweeps should be equal to 1.')
|
||||
self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
|
||||
msg='Sweep is complete but is_sweep_done is False.')
|
||||
# Col init ops should run. Col sweep not completed.
|
||||
mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
|
||||
self.assertTrue(sess.run(self._col_prep_done),
|
||||
msg='col_prep not run by the sweep_hook')
|
||||
self.assertFalse(sess.run(is_row_sweep_var),
|
||||
msg='Col sweep is not complete but is_row_sweep is '
|
||||
'True.')
|
||||
self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
|
||||
msg='Sweep is not complete but is_sweep_done is True.')
|
||||
# Col sweep completed.
|
||||
mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
|
||||
self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
|
||||
msg='Sweep is complete but is_sweep_done is False.')
|
||||
self.assertTrue(sess.run(completed_sweeps_var) == 2,
|
||||
msg='Completed sweeps should be equal to 2.')
|
||||
# Row sweep.
|
||||
mon_sess.run(train_op)
|
||||
self.assertTrue(sess.run(init_done),
|
||||
msg='init op not run by the Sweephook')
|
||||
self.assertTrue(sess.run(row_prep_done),
|
||||
msg='row_prep_op not run by the SweepHook')
|
||||
self.assertTrue(sess.run(row_train_done),
|
||||
msg='row_train_op not run by the SweepHook')
|
||||
self.assertTrue(
|
||||
sess.run(is_row_sweep_var),
|
||||
msg='Row sweep is not complete but is_row_sweep_var is False.')
|
||||
# Col sweep.
|
||||
mon_sess.run(mark_sweep_done)
|
||||
mon_sess.run(train_op)
|
||||
self.assertTrue(sess.run(col_prep_done),
|
||||
msg='col_prep_op not run by the SweepHook')
|
||||
self.assertTrue(sess.run(col_train_done),
|
||||
msg='col_train_op not run by the SweepHook')
|
||||
self.assertFalse(
|
||||
sess.run(is_row_sweep_var),
|
||||
msg='Col sweep is not complete but is_row_sweep_var is True.')
|
||||
# Row sweep.
|
||||
mon_sess.run(mark_sweep_done)
|
||||
mon_sess.run(train_op)
|
||||
self.assertTrue(
|
||||
sess.run(is_row_sweep_var),
|
||||
msg='Col sweep is complete but is_row_sweep_var is False.')
|
||||
|
||||
|
||||
class StopAtSweepHookTest(test.TestCase):
|
||||
|
@ -224,7 +224,8 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
|
||||
`(x, y)` to a transformed *input* point
|
||||
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
|
||||
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
|
||||
the transform mapping input points to output points.
|
||||
the transform mapping input points to output points. Note that gradients
|
||||
are not backpropagated into transformation parameters.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
|
||||
Returns:
|
||||
|
@ -68,6 +68,7 @@ py_test(
|
||||
srcs = ["layer_collection_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
|
||||
"//tensorflow/contrib/kfac/python/ops:layer_collection",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -75,6 +76,7 @@ py_test(
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:random_seed",
|
||||
"//tensorflow/python:variable_scope",
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_blocks
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_factors
|
||||
from tensorflow.contrib.kfac.python.ops import layer_collection
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -25,6 +26,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
@ -105,8 +107,10 @@ class LayerCollectionTest(test.TestCase):
|
||||
array_ops.constant(4), [1, 1, 1, 1], 'SAME',
|
||||
array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
|
||||
lc.register_conv2d(
|
||||
array_ops.constant(4), [1, 1, 1, 1], 'SAME',
|
||||
array_ops.ones((1, 1, 1, 1)), array_ops.constant(3),
|
||||
array_ops.constant(4), [1, 1, 1, 1],
|
||||
'SAME',
|
||||
array_ops.ones((1, 1, 1, 1)),
|
||||
array_ops.constant(3),
|
||||
approx=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.register_generic(
|
||||
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
|
||||
@ -122,8 +126,8 @@ class LayerCollectionTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
lc = layer_collection.LayerCollection()
|
||||
key = array_ops.constant(1)
|
||||
lc.register_fully_connected(key,
|
||||
array_ops.constant(2), array_ops.constant(3))
|
||||
lc.register_fully_connected(key, array_ops.constant(2),
|
||||
array_ops.constant(3))
|
||||
with self.assertRaises(ValueError):
|
||||
lc.register_generic(key, 16)
|
||||
|
||||
@ -191,8 +195,8 @@ class LayerCollectionTest(test.TestCase):
|
||||
|
||||
lc.register_block((x, y), MockFisherBlock('foo'))
|
||||
self.assertEqual(
|
||||
set([MockFisherBlock('2'), MockFisherBlock('foo')]),
|
||||
set(lc.get_blocks()))
|
||||
set([MockFisherBlock('2'), MockFisherBlock('foo')]), set(
|
||||
lc.get_blocks()))
|
||||
|
||||
def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
@ -464,6 +468,66 @@ class LayerCollectionTest(test.TestCase):
|
||||
use_count_map = lc.get_use_count_map()
|
||||
self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
|
||||
|
||||
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
|
||||
x = variable_scope.get_variable('x', shape=())
|
||||
y = variable_scope.get_variable('y', shape=())
|
||||
z = variable_scope.get_variable('z', shape=())
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.define_linked_parameters((x, y))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
lc.define_linked_parameters((x, z))
|
||||
|
||||
def testIdentifySubsetPreviouslyRegisteredTensor(self):
|
||||
x = variable_scope.get_variable('x', shape=())
|
||||
y = variable_scope.get_variable('y', shape=())
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.define_linked_parameters((x, y))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
lc.define_linked_parameters(x)
|
||||
|
||||
def testSpecifyApproximation(self):
|
||||
w_0 = variable_scope.get_variable('w_0', [10, 10])
|
||||
w_1 = variable_scope.get_variable('w_1', [10, 10])
|
||||
|
||||
b_0 = variable_scope.get_variable('b_0', [10])
|
||||
b_1 = variable_scope.get_variable('b_1', [10])
|
||||
|
||||
x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
|
||||
x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
|
||||
|
||||
pre_bias_0 = math_ops.matmul(x_0, w_0)
|
||||
pre_bias_1 = math_ops.matmul(x_1, w_1)
|
||||
|
||||
# Build the fully connected layers in the graph.
|
||||
pre_bias_0 + b_0 # pylint: disable=pointless-statement
|
||||
pre_bias_1 + b_1 # pylint: disable=pointless-statement
|
||||
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.define_linked_parameters(
|
||||
w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.define_linked_parameters(
|
||||
w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.define_linked_parameters(
|
||||
b_0, approximation=layer_collection.APPROX_FULL_NAME)
|
||||
lc.define_linked_parameters(
|
||||
b_1, approximation=layer_collection.APPROX_FULL_NAME)
|
||||
|
||||
lc.register_fully_connected(w_0, x_0, pre_bias_0)
|
||||
lc.register_fully_connected(
|
||||
w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
|
||||
self.assertIsInstance(lc.fisher_blocks[w_0],
|
||||
fisher_blocks.FullyConnectedDiagonalFB)
|
||||
self.assertIsInstance(lc.fisher_blocks[w_1],
|
||||
fisher_blocks.FullyConnectedKFACBasicFB)
|
||||
|
||||
lc.register_generic(b_0, batch_size=1)
|
||||
lc.register_generic(
|
||||
b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
|
||||
self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -38,12 +38,26 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# Names for various approximations that can be requested for Fisher blocks.
|
||||
APPROX_KRONECKER_NAME = "kron"
|
||||
APPROX_DIAGONAL_NAME = "diagonal"
|
||||
APPROX_FULL_NAME = "full"
|
||||
|
||||
_GENERIC_APPROX_TO_BLOCK_TYPES = {
|
||||
APPROX_FULL_NAME: fb.FullFB,
|
||||
APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
|
||||
}
|
||||
|
||||
_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = {
|
||||
APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
|
||||
APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
|
||||
}
|
||||
|
||||
_CONV2D_APPROX_TO_BLOCK_TYPES = {
|
||||
APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
|
||||
APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
|
||||
}
|
||||
|
||||
# Possible value for 'reuse' keyword argument. Sets 'reuse' to
|
||||
# tf.get_variable_scope().reuse.
|
||||
VARIABLE_SCOPE = "VARIABLE_SCOPE"
|
||||
@ -51,6 +65,14 @@ VARIABLE_SCOPE = "VARIABLE_SCOPE"
|
||||
# TODO(jamesmartens): need to add find_canonical_output back into this somewhere
|
||||
|
||||
|
||||
def ensure_sequence(obj):
|
||||
"""If `obj` isn't a tuple or list, return a tuple containing `obj`."""
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return obj
|
||||
else:
|
||||
return (obj,)
|
||||
|
||||
|
||||
class LayerParametersDict(OrderedDict):
|
||||
"""An OrderedDict where keys are Tensors or tuples of Tensors.
|
||||
|
||||
@ -110,9 +132,14 @@ class LayerCollection(object):
|
||||
def __init__(self, graph=None, name="LayerCollection"):
|
||||
self.fisher_blocks = LayerParametersDict()
|
||||
self.fisher_factors = OrderedDict()
|
||||
self._linked_parameters = dict(
|
||||
) # dict mapping sets of variables to optionally specified approximations.
|
||||
self._graph = graph or ops.get_default_graph()
|
||||
self._loss_dict = {} # {str: LossFunction}
|
||||
self._subgraph = None
|
||||
self._default_generic_approximation = APPROX_FULL_NAME
|
||||
self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
|
||||
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
|
||||
|
||||
with variable_scope.variable_scope(None, default_name=name) as scope:
|
||||
self._var_scope = scope.name
|
||||
@ -122,6 +149,70 @@ class LayerCollection(object):
|
||||
"""LossFunctions registered with this LayerCollection."""
|
||||
return list(self._loss_dict.values())
|
||||
|
||||
def is_variable_registered(self, variable):
|
||||
"""Checks whether the variable has already been registered.
|
||||
|
||||
Args:
|
||||
variable: A single variable or tensor.
|
||||
Returns:
|
||||
True if the variable has been registered either by itself or as part of a
|
||||
tuple.
|
||||
"""
|
||||
return any([
|
||||
variable in key if isinstance(key, (tuple, list)) else variable == key
|
||||
for key in self.fisher_blocks.keys()
|
||||
])
|
||||
|
||||
@property
|
||||
def linked_parameters(self):
|
||||
"""Groups of parameters with an optionally specified approximation.
|
||||
|
||||
Linked parameters can be added using `define_linked_parameters`.
|
||||
If an approximation is specified, then this approximation will be used
|
||||
when registering a layer with exactly these parameters, unless an
|
||||
approximation is specified when calling the registration function.
|
||||
|
||||
Returns:
|
||||
A `dict` mapping tuples of parameters to an optional string.
|
||||
"""
|
||||
return self._linked_parameters
|
||||
|
||||
@property
|
||||
def default_generic_approximation(self):
|
||||
return self._default_generic_approximation
|
||||
|
||||
@default_generic_approximation.setter
|
||||
def default_generic_approximation(self, value):
|
||||
if value not in _GENERIC_APPROX_TO_BLOCK_TYPES:
|
||||
raise ValueError(
|
||||
"{} is not a valid approximation for generic variables.".format(
|
||||
value))
|
||||
self._default_generic_approximation = value
|
||||
|
||||
@property
|
||||
def default_fully_connected_approximation(self):
|
||||
return self._default_fully_connected_approximation
|
||||
|
||||
@default_fully_connected_approximation.setter
|
||||
def default_fully_connected_approximation(self, value):
|
||||
if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
|
||||
raise ValueError(
|
||||
"{} is not a valid approximation for fully connected layers.".format(
|
||||
value))
|
||||
self._default_fully_connected_approximation = value
|
||||
|
||||
@property
|
||||
def default_conv2d_approximation(self):
|
||||
return self._default_convolution_2d_approximation
|
||||
|
||||
@default_conv2d_approximation.setter
|
||||
def default_conv2d_approximation(self, value):
|
||||
if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
|
||||
raise ValueError(
|
||||
"{} is not a valid approximation for 2d convolutional layers.".format(
|
||||
value))
|
||||
self._default_convolution_2d_approximation = value
|
||||
|
||||
def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
|
||||
"""Validates and registers the layer_key associated with the fisher_block.
|
||||
|
||||
@ -187,7 +278,8 @@ class LayerCollection(object):
|
||||
# Find all keys that are either supersets or subsets of 'layer_key'.
|
||||
inclusions = {
|
||||
fisher_elt
|
||||
for layer_elt in layer_key for fisher_elt in self.fisher_blocks
|
||||
for layer_elt in layer_key
|
||||
for fisher_elt in self.fisher_blocks
|
||||
if self._equal_or_subset(layer_elt, fisher_elt)
|
||||
}
|
||||
|
||||
@ -294,6 +386,49 @@ class LayerCollection(object):
|
||||
def subgraph(self):
|
||||
return self._subgraph
|
||||
|
||||
def define_linked_parameters(self, params, approximation=None):
|
||||
"""Identify a set of parameters that should be grouped together.
|
||||
|
||||
During automatic graph scanning, any matches containing variables that have
|
||||
been identified as part of a linked group will be filtered out unless
|
||||
the match parameters are exactly equal to the ones specified in the linked
|
||||
group.
|
||||
|
||||
Args:
|
||||
params: A variable, or a tuple or list of variables. The variables
|
||||
to be linked.
|
||||
approximation: Optional string specifying the type of approximation to use
|
||||
for these variables. If unspecified, this layer collection's default
|
||||
approximation for the layer type will be used.
|
||||
|
||||
Raises:
|
||||
ValueError: If the parameters were already registered in a layer or
|
||||
identified as part of an incompatible group.
|
||||
"""
|
||||
params = frozenset(ensure_sequence(params))
|
||||
|
||||
# Check if any of the variables in 'params' is already in
|
||||
# 'self.fisher_blocks.keys()'.
|
||||
for registered_params, fisher_block in self.fisher_blocks.items():
|
||||
registered_params_set = set(ensure_sequence(registered_params))
|
||||
for variable in params:
|
||||
if (variable in registered_params_set and
|
||||
params != registered_params_set):
|
||||
raise ValueError(
|
||||
"Can't link parameters {}, variable {} was already registered in "
|
||||
"group {} with layer {}".format(params, variable,
|
||||
registered_params, fisher_block))
|
||||
|
||||
# Check if any of the variables in 'params' is already in
|
||||
# 'self.linked_parameters'.
|
||||
for variable in params:
|
||||
for other_linked_params in self.linked_parameters:
|
||||
if variable in other_linked_params:
|
||||
raise ValueError("Can't link parameters {}, variable {} was already "
|
||||
"linked in group {}.".format(params, variable,
|
||||
other_linked_params))
|
||||
self._linked_parameters[params] = approximation
|
||||
|
||||
def create_subgraph(self):
|
||||
if not self.losses:
|
||||
raise ValueError("Must have at least one registered loss.")
|
||||
@ -307,11 +442,19 @@ class LayerCollection(object):
|
||||
return math_ops.add_n(
|
||||
tuple(loss.evaluate_on_sample() for loss in self.losses))
|
||||
|
||||
def _get_linked_approx(self, params):
|
||||
"""If params were linked, return their specified approximation."""
|
||||
params_set = frozenset(ensure_sequence(params))
|
||||
if params_set in self.linked_parameters:
|
||||
return self.linked_parameters[params_set]
|
||||
else:
|
||||
return None
|
||||
|
||||
def register_fully_connected(self,
|
||||
params,
|
||||
inputs,
|
||||
outputs,
|
||||
approx=APPROX_KRONECKER_NAME,
|
||||
approx=None,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a fully connected layer.
|
||||
|
||||
@ -332,15 +475,15 @@ class LayerCollection(object):
|
||||
KeyError: If reuse == True but no FisherBlock found for 'params'.
|
||||
ValueError: If reuse == True and FisherBlock found but of the wrong type.
|
||||
"""
|
||||
approx_to_block_types = {
|
||||
APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
|
||||
APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
|
||||
}
|
||||
if approx is None:
|
||||
approx = self._get_linked_approx(params)
|
||||
if approx is None:
|
||||
approx = self.default_fully_connected_approximation
|
||||
|
||||
if approx not in approx_to_block_types:
|
||||
if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
|
||||
raise ValueError("Bad value {} for approx.".format(approx))
|
||||
|
||||
block_type = approx_to_block_types[approx]
|
||||
block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
|
||||
has_bias = isinstance(params, (tuple, list))
|
||||
|
||||
block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
|
||||
@ -352,7 +495,7 @@ class LayerCollection(object):
|
||||
padding,
|
||||
inputs,
|
||||
outputs,
|
||||
approx=APPROX_KRONECKER_NAME,
|
||||
approx=None,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a convolutional layer.
|
||||
|
||||
@ -377,15 +520,16 @@ class LayerCollection(object):
|
||||
KeyError: If reuse == True but no FisherBlock found for 'params'.
|
||||
ValueError: If reuse == True and FisherBlock found but of the wrong type.
|
||||
"""
|
||||
approx_to_block_types = {
|
||||
APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
|
||||
APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
|
||||
}
|
||||
|
||||
if approx not in approx_to_block_types:
|
||||
if approx is None:
|
||||
approx = self._get_linked_approx(params)
|
||||
if approx is None:
|
||||
approx = self.default_conv2d_approximation
|
||||
|
||||
if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES:
|
||||
raise ValueError("Bad value {} for approx.".format(approx))
|
||||
|
||||
block_type = approx_to_block_types[approx]
|
||||
block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
|
||||
block = self.register_block(
|
||||
params, block_type(self, params, strides, padding), reuse=reuse)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
@ -393,7 +537,7 @@ class LayerCollection(object):
|
||||
def register_generic(self,
|
||||
params,
|
||||
batch_size,
|
||||
approx=APPROX_DIAGONAL_NAME,
|
||||
approx=None,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a generic layer.
|
||||
|
||||
@ -413,15 +557,16 @@ class LayerCollection(object):
|
||||
KeyError: If reuse == True but no FisherBlock found for 'params'.
|
||||
ValueError: If reuse == True and FisherBlock found but of the wrong type.
|
||||
"""
|
||||
approx_to_block_types = {
|
||||
APPROX_FULL_NAME: fb.FullFB,
|
||||
APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
|
||||
}
|
||||
|
||||
if approx not in approx_to_block_types:
|
||||
if approx is None:
|
||||
approx = self._get_linked_approx(params)
|
||||
if approx is None:
|
||||
approx = self.default_generic_approximation
|
||||
|
||||
if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES:
|
||||
raise ValueError("Bad value {} for approx.".format(approx))
|
||||
|
||||
block_type = approx_to_block_types[approx]
|
||||
block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx]
|
||||
block = self.register_block(params, block_type(self, params), reuse=reuse)
|
||||
block.register_additional_minibatch(batch_size)
|
||||
|
||||
@ -560,10 +705,10 @@ class LayerCollection(object):
|
||||
try:
|
||||
hash(args)
|
||||
except TypeError:
|
||||
raise TypeError((
|
||||
"Unable to use (cls, args) = ({}, {}) as a key in "
|
||||
"LayerCollection.fisher_factors. The pair cannot be hashed."
|
||||
).format(cls, args))
|
||||
raise TypeError(
|
||||
("Unable to use (cls, args) = ({}, {}) as a key in "
|
||||
"LayerCollection.fisher_factors. The pair cannot be hashed.").format(
|
||||
cls, args))
|
||||
|
||||
with variable_scope.variable_scope(self._var_scope):
|
||||
return utils.setdefault(self.fisher_factors, (cls, args),
|
||||
|
@ -44,7 +44,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
momentum=0.,
|
||||
momentum_type="regular",
|
||||
norm_constraint=None,
|
||||
name="KFAC",):
|
||||
name="KFAC",
|
||||
estimation_mode="gradients"):
|
||||
"""Initializes the KFAC optimizer with the given settings.
|
||||
|
||||
Args:
|
||||
@ -72,6 +73,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
specified value. May only be used with momentum type 'regular'.
|
||||
(Default: None)
|
||||
name: The name for this optimizer. (Default: 'KFAC')
|
||||
estimation_mode: The type of estimator to use for the Fishers. Can be
|
||||
'gradients', 'empirical', 'curvature_propagation', or 'exact'.
|
||||
(Default: 'gradients'). See the doc-string for FisherEstimator for
|
||||
more a more detailed description of these options.
|
||||
|
||||
Raises:
|
||||
ValueError: If the momentum type is unsupported.
|
||||
@ -86,7 +91,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
variables = tf_variables.trainable_variables()
|
||||
|
||||
self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping,
|
||||
layer_collection)
|
||||
layer_collection,
|
||||
estimation_mode=estimation_mode)
|
||||
|
||||
momentum_type = momentum_type.lower()
|
||||
legal_momentum_types = ["regular", "adam", "qmodel"]
|
||||
|
@ -1403,7 +1403,8 @@ def dropout(inputs,
|
||||
noise_shape=None,
|
||||
is_training=True,
|
||||
outputs_collections=None,
|
||||
scope=None):
|
||||
scope=None,
|
||||
seed=None):
|
||||
"""Returns a dropout op applied to the input.
|
||||
|
||||
With probability `keep_prob`, outputs the input element scaled up by
|
||||
@ -1421,6 +1422,8 @@ def dropout(inputs,
|
||||
Otherwise, inputs is returned.
|
||||
outputs_collections: Collection to add the outputs.
|
||||
scope: Optional scope for name_scope.
|
||||
seed: A Python integer. Used to create random seeds. See
|
||||
@{tf.set_random_seed} for behavior.
|
||||
|
||||
Returns:
|
||||
A tensor representing the output of the operation.
|
||||
@ -1430,6 +1433,7 @@ def dropout(inputs,
|
||||
inputs = ops.convert_to_tensor(inputs)
|
||||
layer = core_layers.Dropout(rate=1 - keep_prob,
|
||||
noise_shape=noise_shape,
|
||||
seed=seed,
|
||||
name=sc.name,
|
||||
_scope=sc)
|
||||
outputs = layer.apply(inputs, training=is_training)
|
||||
|
@ -1345,11 +1345,20 @@ class DropoutTest(test.TestCase):
|
||||
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
|
||||
output = _layers.dropout(images)
|
||||
num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0))
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial])
|
||||
self.assertLess(num_elem, num_elem_initial / 2 + 0.1)
|
||||
self.assertGreater(num_elem, num_elem_initial / 2 - 0.1)
|
||||
|
||||
def testDropoutSeed(self):
|
||||
"""Test that providing the same seed produces the same result."""
|
||||
height, width = 10, 10
|
||||
with self.test_session() as sess:
|
||||
images = random_ops.random_uniform(
|
||||
(5, height, width, 3), seed=1, name='images')
|
||||
output1 = _layers.dropout(images, seed=1)
|
||||
output2 = _layers.dropout(images, seed=1)
|
||||
self.assertAllEqual(*sess.run([output1, output2]))
|
||||
|
||||
def testCreateDropoutNoTraining(self):
|
||||
height, width = 3, 3
|
||||
with self.test_session() as sess:
|
||||
@ -1358,7 +1367,6 @@ class DropoutTest(test.TestCase):
|
||||
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
|
||||
output = _layers.dropout(images, is_training=False)
|
||||
num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0))
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial])
|
||||
self.assertEqual(num_elem, num_elem_initial)
|
||||
outputs, inputs = sess.run([output, images])
|
||||
|
@ -33,6 +33,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from tensorflow.contrib.layers.python.layers import feature_column
|
||||
@ -644,18 +645,22 @@ def make_best_model_export_strategy(serving_input_fn,
|
||||
|
||||
# TODO(b/67013778): Revisit this approach when corresponding changes to
|
||||
# TF Core are finalized.
|
||||
def extend_export_strategy(base_export_strategy, post_export_fn,
|
||||
post_export_name):
|
||||
def extend_export_strategy(base_export_strategy,
|
||||
post_export_fn,
|
||||
post_export_name=None):
|
||||
"""Extend ExportStrategy, calling post_export_fn after export.
|
||||
|
||||
Args:
|
||||
base_export_strategy: An ExportStrategy that can be passed to the Experiment
|
||||
constructor.
|
||||
post_export_fn: A user-specified function to call after exporting the
|
||||
SavedModel. Takes the export directory as an argument, and returns
|
||||
a string path to a (potentially different) SavedModel.
|
||||
SavedModel. Takes two arguments - the path to the SavedModel exported by
|
||||
base_export_strategy and the directory where to export the SavedModel
|
||||
modified by the post_export_fn. Returns the path to the exported
|
||||
SavedModel.
|
||||
post_export_name: The directory name under the export base directory where
|
||||
SavedModels generated by the post_export_fn will be written.
|
||||
SavedModels generated by the post_export_fn will be written. If None, the
|
||||
directory name of base_export_strategy is used.
|
||||
|
||||
Returns:
|
||||
An ExportStrategy that can be passed to the Experiment constructor.
|
||||
@ -675,12 +680,24 @@ def extend_export_strategy(base_export_strategy, post_export_fn,
|
||||
|
||||
Raises:
|
||||
ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
|
||||
and `default_output_alternative_key` was specified.
|
||||
and `default_output_alternative_key` was specified or if post_export_fn
|
||||
does not return a valid directory.
|
||||
"""
|
||||
export_dir = base_export_strategy.export(estimator, export_dir_base,
|
||||
checkpoint_path)
|
||||
if post_export_fn:
|
||||
export_dir = post_export_fn(export_dir)
|
||||
return export_dir
|
||||
tmp_base_export_dir = tempfile.mkdtemp()
|
||||
tmp_base_export = base_export_strategy.export(
|
||||
estimator, tmp_base_export_dir, checkpoint_path)
|
||||
tmp_post_export_dir = tempfile.mkdtemp()
|
||||
tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)
|
||||
|
||||
return export_strategy.ExportStrategy(post_export_name, export_fn)
|
||||
if not tmp_post_export.startswith(tmp_post_export_dir):
|
||||
raise ValueError('post_export_fn must return a sub-directory of {}'
|
||||
.format(tmp_post_export_dir))
|
||||
export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir)
|
||||
|
||||
gfile.Rename(
|
||||
os.path.join(tmp_post_export_dir, export_relpath),
|
||||
os.path.join(export_dir_base, export_relpath))
|
||||
return os.path.join(export_dir_base, export_relpath)
|
||||
|
||||
name = post_export_name if post_export_name else base_export_strategy.name
|
||||
return export_strategy.ExportStrategy(name, export_fn)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user