commit ae22de6596

@@ -149,4 +149,12 @@ ClientLibrary::GetOrCreateCompileOnlyClient(
  return cl;
}

/* static */ void ClientLibrary::DestroyLocalInstances() {
  ClientLibrary& client_library = Singleton();
  tensorflow::mutex_lock lock(client_library.service_mutex_);

  client_library.local_instances_.clear();
  client_library.compile_only_instances_.clear();
}

}  // namespace xla

@@ -93,6 +93,11 @@ class ClientLibrary {
  static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
      perftools::gputools::Platform* platform = nullptr);

  // Clears the local instance and compile only instance caches. The client
  // pointers returned by the previous GetOrCreateLocalClient() or
  // GetOrCreateCompileOnlyClient() invocations are not valid anymore.
  static void DestroyLocalInstances();

 private:
  // Returns the singleton instance of ClientLibrary.
  static ClientLibrary& Singleton();

@@ -1208,6 +1208,7 @@ cc_test(
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/core:lib",
        "//tensorflow/core:test_main",
    ],

@@ -1586,6 +1586,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
  // module, invalidating iteration.
  std::vector<HloComputation*> computations;
  for (auto& comp : module->computations()) {
    if (comp->IsFusionComputation()) {
      continue;
    }
    computations.push_back(comp.get());
  }
  for (auto& comp : computations) {

@@ -268,6 +268,9 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
  // module, invalidating iteration.
  std::vector<HloComputation*> computations;
  for (auto& comp : module->computations()) {
    if (comp->IsFusionComputation()) {
      continue;
    }
    computations.push_back(comp.get());
  }
  for (auto& comp : computations) {
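
The two hunks above show the pattern this commit applies across most passes: snapshot the non-fusion computations before iterating, because fusing instructions adds fusion computations to the module and would otherwise invalidate or re-visit the iteration. A minimal standalone sketch of that pattern follows; the Module and Computation types here are simplified stand-ins for illustration, not the real XLA classes.

// Sketch (not part of the commit); simplified stand-in types, not the XLA API.
#include <memory>
#include <vector>

struct Computation {
  bool is_fusion = false;
  bool IsFusionComputation() const { return is_fusion; }
};

struct Module {
  std::vector<std::unique_ptr<Computation>> computations;
};

// Snapshot the non-fusion computations so later mutation of
// module->computations cannot invalidate the traversal.
std::vector<Computation*> NonFusionComputations(Module* module) {
  std::vector<Computation*> result;
  for (auto& comp : module->computations) {
    if (comp->IsFusionComputation()) {
      continue;  // Fusion computations are reached via their fusion instruction.
    }
    result.push_back(comp.get());
  }
  return result;
}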

@@ -1219,6 +1219,9 @@ void BufferAssigner::BuildColocatedBufferSets(
  const TuplePointsToAnalysis& points_to_analysis =
      buffer_liveness.points_to_analysis();
  for (const HloComputation* computation : module->MakeComputationPostOrder()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (const HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      const HloOpcode opcode = instruction->opcode();

@@ -1386,6 +1389,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
  // their own BufferAllocation.
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(AssignBuffersForComputation(
        computation, module->config().debug_options(),
        /*is_thread_local=*/true, colocated_buffers, colocated_allocations,

@@ -47,6 +47,9 @@ StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
tensorflow::Status BufferLiveness::Analyze() {
  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
  for (auto& computation : module_->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    // Gather all instructions whose buffers might alias other instructions into
    // the set aliased_buffers_. This includes those contained as a tuple
    // element in other instruction's output.

@@ -551,6 +551,9 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Add copies of computation root instructions, if needed.
  FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices;
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    VLOG(2) << "computation " << computation->name();
    InstructionCopier root_copier(computation->root_instruction(),
                                  /*copy_users=*/{});

@@ -519,6 +519,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
        new std::map<HloInstruction*, string>());
    for (auto embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (embedded_computation->IsFusionComputation()) {
        continue;
      }
      auto parallel_computation_iter =
          parallel_computations.find(embedded_computation);
      // All parallel computations are considered to be an entry computation for

@@ -591,6 +594,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(

    for (auto embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (embedded_computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(embedded_computation,

@@ -755,6 +761,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
    HloComputation* computation = module->entry_computation();
    for (auto embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (embedded_computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(embedded_computation,

@@ -125,6 +125,9 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    HloInstruction* root = computation->root_instruction();
    // Copy root instruction if it does not define its own top-level buffer.
    // TODO(b/32885001) Remove these copies (at least for the unambiguous case).

@@ -293,12 +293,19 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
StatusOr<bool> FusionMerger::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "FusionMerger for module: " << module->name();
  std::vector<HloComputation*> computations;
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    computations.push_back(computation.get());
  }
  for (auto& computation : computations) {
    VLOG(1) << "Before running FusionInstructionMerger for computation: "
            << computation->name();
    XLA_VLOG_LINES(3, computation->ToString());

    FusionInstructionMerger fusion_merger(computation.get());
    FusionInstructionMerger fusion_merger(computation);
    TF_RETURN_IF_ERROR(fusion_merger.Run());
    changed |= fusion_merger.changed();

@@ -120,7 +120,8 @@ GpuHloOrdering::GpuHloOrdering(
  // do that yet since it's hard to ensure that the order here is the order used
  // by IrEmitterNested. And mismatched ordering bugs would be hard to find.
  for (auto& computation : module->computations()) {
    if (computation.get() != module->entry_computation()) {
    if (computation.get() != module->entry_computation() &&
        !computation->IsFusionComputation()) {
      predecessors_.emplace(computation.get(),
                            computation->ComputeReachability());
    }

@@ -42,6 +42,9 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
  bool changed = false;

  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      // Skip dead code.
      if (instruction->user_count() == 0 &&

@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/logging.h"

#include "tensorflow/compiler/xla/statusor.h"

@@ -329,7 +330,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}

using FusionCostAnalysis = ::testing::Test;
using FusionCostAnalysis = HloTestBase;

TEST_F(FusionCostAnalysis, LoopFusion) {
  // Do this 4 times with different per-second rates to test the computation of

@@ -345,32 +346,32 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
  // mul = Mul(exp, C3)
  // sub = Sub(mul, clamp)
  // tuple = Tuple({sub, sub, mul, C1})
  auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
      /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2));
  auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
      /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2));
  auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
      /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2));
  HloComputation::Builder builder(TestName());
  auto c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
          /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
          /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
  auto c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
          /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
  auto clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
  auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});

  auto add = HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(),
                                          c2.get());
  auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp,
                                             c2.get(), add.get(), add.get());
  auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get());
  auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply,
                                          exp.get(), c3.get());
  auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract,
                                          mul.get(), clamp.get());
  auto tuple = HloInstruction::CreateTuple(
      {sub.get(), sub.get(), mul.get(), c1.get()});

  auto fusion = HloInstruction::CreateFusion(
      r2f32, HloInstruction::FusionKind::kLoop, tuple.get());
  fusion->FuseInstruction(sub.get());
  fusion->FuseInstruction(mul.get());
  fusion->FuseInstruction(exp.get());
  fusion->FuseInstruction(clamp.get());
  fusion->FuseInstruction(add.get());
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // The time given these rates at i == 0 is exactly even among the properties
  // at 1.0 seconds. For other values, one of the rates is slower so that it

@@ -398,18 +399,21 @@ TEST_F(FusionCostAnalysis, NoLayout) {
  Shape shape_without_layout = shape_with_layout;
  shape_without_layout.clear_layout();

  auto c1 = HloInstruction::CreateConstant(
      Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
  auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
  HloComputation::Builder builder(TestName());
  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));

  auto broadcast =
      HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
  auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
                                          c1.get(), broadcast.get());
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_with_layout, HloOpcode::kAdd, c1, broadcast));

  auto fusion = HloInstruction::CreateFusion(
      shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
  fusion->FuseInstruction(broadcast.get());
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {add, broadcast}, HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

@@ -92,6 +92,9 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
StatusOr<bool> HloCSE::Run(HloModule* module) {
  bool changed = false;
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    changed |= CombineConstants(computation.get(), is_layout_sensitive_);

    std::list<HloInstruction*> post_order =

@@ -38,6 +38,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
  bool changed = false;

  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    std::unordered_set<HloInstruction*> live_instructions;
    TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
        [&live_instructions](HloInstruction* instruction) {

@@ -560,19 +560,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
    HloInstruction* instruction_to_fuse) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(instruction_to_fuse->IsFusable());

  if (GetModule()) {
    XLA_VLOG_LINES(1, GetModule()->ToString());
  }
  HloInstruction* clone = nullptr;
  if (fused_instructions_computation_ == nullptr) {
  if (called_computations_.empty()) {
    // New fusion instruction.
    auto builder = HloComputation::Builder("fused_computation", true);
    builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
    fused_instructions_computation_ = builder.Build();
    called_computations_.push_back(
        CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
    clone = fused_expression_root();
    clone->parent_fusion_instruction_ = this;
  } else {
    CHECK(fused_instructions_computation_ != nullptr &&
          fused_instructions_computation_->IsFusionComputation());
    clone = fused_instructions_computation_->AddInstruction(
    clone = fused_instructions_computation()->AddInstruction(
        instruction_to_fuse->Clone(/*suffix=*/""));
    clone->parent_fusion_instruction_ = this;
    // instruction_to_fuse is necessarily an operand of the fusion instruction.

@@ -583,7 +584,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
    CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) !=
          operands_.end());
    const std::vector<HloInstruction*>& fused_parameters_ =
        fused_instructions_computation_->parameter_instructions();
        fused_instructions_computation()->parameter_instructions();
    for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
      if (instruction_to_fuse == operands_[operand_num]) {
        // replace the fused parameter instruction's uses with the clone.

@@ -593,7 +594,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
        // Remove the corresponding fused parameter and operand from their
        // respective vectors.
        TF_CHECK_OK(
            fused_instructions_computation_->RemoveParameter(operand_num));
            fused_instructions_computation()->RemoveParameter(operand_num));
        operands_.erase(operands_.begin() + operand_num);
        break;
      }

@@ -605,7 +606,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(

    // Reread the parameters in the computation.
    const std::vector<HloInstruction*>& fused_parameters_ =
        fused_instructions_computation_->parameter_instructions();
        fused_instructions_computation()->parameter_instructions();

    // Add each operand of the clone as an operand of the fusion instruction. A
    // complication is that some clone operands may already be operands of the

@@ -638,7 +639,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
          CreateParameter(param_no, operand->shape(), param_name);

      param_instruction->parent_fusion_instruction_ = this;
      fused_param = fused_instructions_computation_->AddParameter(
      fused_param = fused_instructions_computation()->AddParameter(
          std::move(param_instruction));
      AppendOperand(operand);
    }

@@ -652,7 +653,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
      called_computations_.push_back(computation);
    }
  }

  return clone;
}

@@ -663,17 +663,15 @@ RandomDistribution HloInstruction::random_distribution() const {

void HloInstruction::CheckFusionInstruction() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(fused_instructions_computation_ != nullptr &&
        fused_instructions_computation_->IsFusionComputation());

  const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
      fused_instructions_computation_->instructions();
      fused_instructions_computation()->instructions();
  // All instructions owned by this fusion instruction must be fused, and the
  // parent fusion instruction of the fused instructions must be 'this'.
  for (auto& instruction : fused_instructions_) {
    CHECK(instruction->IsFused());
    CHECK_EQ(this, instruction->fusion_instruction());
    CHECK_EQ(fused_instructions_computation_.get(), instruction->parent())
    CHECK_EQ(fused_instructions_computation(), instruction->parent())
        << instruction->ToString();
  }

@@ -976,8 +974,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(parent() != nullptr);
  CHECK(fused_instructions_computation_ != nullptr &&
        fused_instructions_computation_->IsFusionComputation());

  auto new_instruction =
      WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));

@@ -992,9 +988,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
  // fused instructions.
  std::vector<HloInstruction*> new_fused_parameters;
  const std::vector<HloInstruction*>& fused_parameters_ =
      fused_instructions_computation_->parameter_instructions();
      fused_instructions_computation()->parameter_instructions();
  const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
      fused_instructions_computation_->instructions();
      fused_instructions_computation()->instructions();

  for (HloInstruction* old_fused_parameter : fused_parameters_) {
    new_fused_instructions.push_back(old_fused_parameter->Clone());

@@ -1028,7 +1024,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
  }
  new_instruction->fusion_kind_ = fusion_kind_;
  auto computation_builder = HloComputation::Builder(
      fused_instructions_computation_->name() + ".clone", true);
      fused_instructions_computation()->name() + ".clone", true);
  // We iterated the fusion instructions in reverse post order which means
  // that we must reverse our new list of fusion instructions.
  for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();

@@ -1037,8 +1033,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
    computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
  }
  auto fused_root_ = fused_expression_root();
  new_instruction->fused_instructions_computation_ =
      computation_builder.Build(FindOrDie(old_to_new, fused_root_));
  new_instruction->called_computations_.push_back(
      CHECK_NOTNULL(GetModule())
          ->AddEmbeddedComputation(
              computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
  new_instruction->set_parent(parent());
  new_instruction->CheckFusionInstruction();
  return new_instruction;

@@ -1769,7 +1767,10 @@ bool HloInstruction::IsFusable() const {

HloComputation* HloInstruction::fused_instructions_computation() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  return fused_instructions_computation_.get();
  CHECK(!called_computations_.empty());
  auto* fused_instructions_computation = called_computations_.front();
  CHECK(fused_instructions_computation->IsFusionComputation());
  return fused_instructions_computation;
}

HloInstruction* HloInstruction::fusion_instruction() const {

@@ -1779,32 +1780,24 @@ HloInstruction* HloInstruction::fusion_instruction() const {

HloInstruction* HloInstruction::fused_expression_root() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(fused_instructions_computation_ != nullptr &&
        fused_instructions_computation_->IsFusionComputation());
  return fused_instructions_computation_->root_instruction();
  return fused_instructions_computation()->root_instruction();
}

HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(fused_instructions_computation_ != nullptr &&
        fused_instructions_computation_->IsFusionComputation());
  return fused_instructions_computation_->parameter_instruction(
  return fused_instructions_computation()->parameter_instruction(
      parameter_number);
}

const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(fused_instructions_computation_ != nullptr &&
        fused_instructions_computation_->IsFusionComputation());
  return fused_instructions_computation_->parameter_instructions();
  return fused_instructions_computation()->parameter_instructions();
}

const std::list<std::unique_ptr<HloInstruction>>&
HloInstruction::fused_instructions() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(fused_instructions_computation_ != nullptr &&
        fused_instructions_computation_->IsFusionComputation());
  return fused_instructions_computation_->instructions();
  return fused_instructions_computation()->instructions();
}

HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
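
The accessor hunks above all reduce to one idea: the fused computation is no longer a dedicated unique_ptr member but the first entry of called_computations_, from which root, parameters, and instructions are all derived. A rough sketch of that invariant, using hypothetical stand-in types rather than the real HloInstruction:

// Sketch (not part of the commit); stand-in types, not the real XLA classes.
#include <cassert>
#include <vector>

struct ComputationSketch {
  bool is_fusion = false;
  bool IsFusionComputation() const { return is_fusion; }
};

struct FusionInstructionSketch {
  // By convention the fused computation is stored first.
  std::vector<ComputationSketch*> called_computations;

  ComputationSketch* fused_instructions_computation() const {
    assert(!called_computations.empty());
    ComputationSketch* fused = called_computations.front();
    assert(fused->IsFusionComputation());
    return fused;  // Root, parameters, and instructions derive from this.
  }
};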

@@ -2039,7 +2032,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,

Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit,
                              bool ignore_control_predecessors) {
  VLOG(2) << "HloInstruction::Accept(" << name() << ")";
  VLOG(3) << "HloInstruction::Accept(" << name() << ")";
  TF_RETURN_IF_ERROR(
      PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
  if (call_finish_visit) {

@@ -2055,8 +2048,11 @@ Status HloInstruction::AcceptWithOperandOrder(
  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order,
                                  /*ignore_control_predecessors=*/false));
  if (call_finish_visit) {
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
  }
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
  return Status::OK();
}

@@ -2458,6 +2454,7 @@ HloModule* HloInstruction::GetModule() const {
}

void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
  string parent_str = parent() == nullptr ? "noparent" : parent()->name();
  name_ = name_uniquer->GetUniqueName(name_);
}

@@ -935,10 +935,6 @@ class HloInstruction {
  // padding of this pad instruction. Only set for pad instructions.
  std::unique_ptr<PaddingConfig> padding_config_;

  // The computation that stores of instructions fused into this fusion
  // instruction. Only set for fusion instructions.
  std::unique_ptr<HloComputation> fused_instructions_computation_;

  // If this instruction is fused into a fusion instruction, this field points
  // to the fusion instruction.
  HloInstruction* parent_fusion_instruction_ = nullptr;

@@ -557,78 +557,89 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
}

TEST_F(HloInstructionTest, SingletonFusionOp) {
  HloComputation::Builder builder(TestName());
  // Create a fusion instruction containing a single unary operation.
  auto constant =
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
  auto exp =
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp}, HloInstruction::FusionKind::kLoop);

  auto fusion = HloInstruction::CreateFusion(
      r0f32_, HloInstruction::FusionKind::kLoop, exp.get());

  EXPECT_THAT(fusion->operands(), ElementsAre(constant.get()));
  EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get()));
  EXPECT_THAT(fusion->operands(), ElementsAre(constant));
  EXPECT_THAT(constant->users(), ElementsAre(fusion));
}

TEST_F(HloInstructionTest, BinaryFusionOp) {
  HloComputation::Builder builder(TestName());
  // Create a fusion instruction containing a single binary operation.
  auto constant1 =
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
  auto constant2 =
      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f));
  auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                          constant1.get(), constant2.get());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {add}, HloInstruction::FusionKind::kLoop);

  auto fusion = HloInstruction::CreateFusion(
      r0f32_, HloInstruction::FusionKind::kLoop, add.get());

  EXPECT_THAT(fusion->operands(),
              ElementsAre(constant1.get(), constant2.get()));
  EXPECT_THAT(constant1->users(),
              UnorderedElementsAre(fusion.get(), add.get()));
  EXPECT_THAT(constant2->users(),
              UnorderedElementsAre(fusion.get(), add.get()));
  EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2));
  EXPECT_THAT(constant1->users(), ElementsAre(fusion));
  EXPECT_THAT(constant2->users(), ElementsAre(fusion));
}

TEST_F(HloInstructionTest, ChainFusionOp) {
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant =
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
  auto exp1 =
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
  auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
  auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
  auto exp3 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));

  auto fusion = HloInstruction::CreateFusion(
      r0f32_, HloInstruction::FusionKind::kLoop, exp3.get());
  fusion->FuseInstruction(exp2.get());
  fusion->FuseInstruction(exp1.get());
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(constant.get()));
  EXPECT_THAT(constant->users(),
              UnorderedElementsAre(fusion.get(), exp1.get()));
  EXPECT_THAT(fusion->operands(), ElementsAre(constant));
  EXPECT_THAT(constant->users(), ElementsAre(fusion));
}

TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant =
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
  auto exp1 =
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
  auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
  OpMetadata metadata;
  metadata.set_op_name("tf_op");
  exp1->set_metadata(metadata);
  exp2->set_metadata(metadata);

  auto fusion = HloInstruction::CreateFusion(
      r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
  auto* fused = fusion->FuseInstruction(exp1.get());
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp2, exp1}, HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->operand(0)->metadata()));
}

TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  HloComputation::Builder builder(TestName());
  // Create a fusion instruction containing a single unary operation.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

@@ -642,33 +653,36 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  std::unique_ptr<HloComputation> computation_x = make_map_computation();
  std::unique_ptr<HloComputation> computation_y = make_map_computation();

  auto constant =
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
  auto map_1_x =
      HloInstruction::CreateMap(scalar_shape, {constant.get()},
                                computation_x.get(), /*static_operands=*/{});
  auto map_2_x =
      HloInstruction::CreateMap(scalar_shape, {map_1_x.get()},
                                computation_x.get(), /*static_operands=*/{});
  auto map_3_y =
      HloInstruction::CreateMap(scalar_shape, {map_2_x.get()},
                                computation_y.get(), /*static_operands=*/{});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap(
      scalar_shape, {constant}, computation_x.get(), /*static_operands=*/{}));
  auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap(
      scalar_shape, {map_1_x}, computation_x.get(), /*static_operands=*/{}));
  auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
      scalar_shape, {map_2_x}, computation_y.get(), /*static_operands=*/{}));

  auto fusion = HloInstruction::CreateFusion(
      scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get());

  EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get()));

  fusion->FuseInstruction(map_2_x.get());
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {map_3_y}, HloInstruction::FusionKind::kLoop);
  auto* fused_computation = fusion->fused_instructions_computation();
  EXPECT_THAT(fusion->called_computations(),
              ElementsAre(computation_y.get(), computation_x.get()));
              ElementsAre(fused_computation, computation_y.get()));

  fusion->FuseInstruction(map_1_x.get());
  EXPECT_THAT(fusion->called_computations(),
              ElementsAre(computation_y.get(), computation_x.get()));
  fusion->FuseInstruction(map_2_x);
  EXPECT_THAT(
      fusion->called_computations(),
      ElementsAre(fused_computation, computation_y.get(), computation_x.get()));

  fusion->FuseInstruction(map_1_x);
  EXPECT_THAT(
      fusion->called_computations(),
      ElementsAre(fused_computation, computation_y.get(), computation_x.get()));
}

TEST_F(HloInstructionTest, ComplexFusionOp) {
  HloComputation::Builder builder(TestName());
  // Fuse all instructions in complicated expression:
  //
  // add = Add(C1, C2)

@@ -680,35 +694,35 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
  //
  // Notable complexities are repeated operands in a same instruction, different
  // shapes, use of value in different expressions.
  auto c1 = HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
  auto c2 = HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f));
  auto c3 = HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f));
  auto c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
  auto c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)));

  auto add =
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get());
  auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp,
                                             c2.get(), add.get(), add.get());
  auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get());
  auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply,
                                          exp.get(), c3.get());
  auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract,
                                          mul.get(), clamp.get());
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  auto clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
  auto tuple =
      HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()});
  builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));

  auto fusion = HloInstruction::CreateFusion(
      r0f32_, HloInstruction::FusionKind::kLoop, tuple.get());
  fusion->FuseInstruction(sub.get());
  fusion->FuseInstruction(mul.get());
  fusion->FuseInstruction(exp.get());
  fusion->FuseInstruction(clamp.get());
  fusion->FuseInstruction(add.get());
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // Operands in the fusion instruction's operands() vector should be in the
  // order in which their users were added fused.
  EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get()));
  EXPECT_THAT(c1->users(),
              UnorderedElementsAre(add.get(), tuple.get(), fusion.get()));
  EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
  EXPECT_THAT(c1->users(), ElementsAre(fusion));
}

// Convenience function for comparing two HloInstructions inside of

@@ -864,7 +878,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
  HloInstruction* max = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));

  auto computation = builder.Build();
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(fusion->IsElementwise());

@@ -906,7 +921,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kSubtract, min, broadcast));

  auto computation = builder.Build();
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(fusion->IsElementwise());

@@ -945,7 +961,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape));

  auto computation = builder.Build();
  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);

@@ -183,6 +183,9 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
  // ordering based on dependencies. ExecutesBefore will return true iff there
  // exists a path in the HLO computation graph from 'a' to 'b'.
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    predecessors_.emplace(computation.get(),
                          computation->ComputeReachability());
  }

@@ -1202,6 +1202,9 @@ StatusOr<bool> HloRematerialization::Run(
  // After DCE, the module sequence may include instructions which no longer
  // exist.
  for (const auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    if (sequence->at(computation.get()).size() !=
        computation->instruction_count()) {
      // A size mismatch between the computation instruction count and the size

@@ -400,6 +400,9 @@ CreateMemoryMinimizingSequence(
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(&module));
  for (const auto& computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(sequence[computation.get()],
                        CreateMemoryMinimizingSequence(
                            *computation, *points_to_analysis, size_function));

@@ -410,6 +413,7 @@ CreateMemoryMinimizingSequence(
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
    const HloComputation& computation,
    const LogicalBuffer::SizeFunction& size_function) {
  CHECK(!computation.IsFusionComputation());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(computation.parent()));
  return CreateMemoryMinimizingSequence(computation, *points_to_analysis,

@@ -211,8 +211,17 @@ bool InstructionFusion::CanFuseOnAllPaths(

StatusOr<bool> InstructionFusion::Run(HloModule* module) {
  bool changed = false;

  std::vector<HloComputation*> computations;
  for (auto& computation : module->computations()) {
    computation_ = computation.get();
    if (computation->IsFusionComputation()) {
      continue;
    }
    computations.push_back(computation.get());
  }
  for (auto& computation : computations) {
    CHECK(!computation->IsFusionComputation());
    computation_ = computation;

    // We want to be able to remove arbitrary instructions from the post order
    // and also compare positions of instructions in the post order. To make

@@ -611,6 +611,9 @@ Status CheckLayouts(
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto& instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.

@@ -1356,6 +1359,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
    if (computation == module->entry_computation()) {
      TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_,
                                          module->entry_computation()));
    } else if (computation->IsFusionComputation()) {
      continue;
    } else {
      ComputationLayout computation_layout(computation->ComputeProgramShape());
      // Setting all embedded computations to the default layout is potentially

@@ -29,7 +29,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
    return root;
  } else {
    tensorflow::strings::StrAppend(&root, separator_, *count);
    // Increment lookup under old 'root' name.
    (*count)++;
    // Initialize count under new 'root' name.
    count = &(generated_names_[root]);
    *count = 1;
    return root;
  }
}
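
The NameUniquer hunk above fixes a collision: after renaming root to root<separator><count>, the counter for the new name must also be seeded, or a later request for that exact string would be handed out twice. A self-contained sketch of the fixed logic follows; the types are assumptions (plain std::string and std::map in place of tensorflow::StringPiece and StrAppend), not the real class.

// Sketch (not part of the commit); plain std types stand in for the TF ones.
#include <map>
#include <string>

class NameUniquerSketch {
 public:
  std::string GetUniqueName(const std::string& prefix) {
    std::string root = prefix.empty() ? "name" : prefix;
    long* count = &generated_names_[root];
    if (*count == 0) {
      *count = 1;
      return root;
    }
    root += separator_ + std::to_string(*count);
    (*count)++;                       // Increment lookup under the old root.
    count = &generated_names_[root];  // Initialize count under the new root.
    *count = 1;
    return root;
  }

 private:
  std::string separator_ = ".";
  std::map<std::string, long> generated_names_;
};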

@@ -26,6 +26,9 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
  VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();

  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    std::vector<HloInstruction*> instructions_to_suffix;

    for (auto& instruction : computation->instructions()) {

@@ -312,10 +312,17 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,

StatusOr<bool> ReshapeMover::Run(HloModule* module) {
  bool changed = false;
  for (const auto& comp : module->computations()) {
  std::vector<HloComputation*> computations;
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    computations.push_back(computation.get());
  }
  for (const auto& comp : computations) {
    for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool did_change,
                          TrySinkReshapeOrTranspose(comp.get(), instruction));
                          TrySinkReshapeOrTranspose(comp, instruction));
      changed |= did_change;
    }
  }

@@ -351,16 +351,15 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto fusion = computation->AddInstruction(HloInstruction::CreateFusion(
      add->shape(), HloInstruction::FusionKind::kLoop, add));
  TF_CHECK_OK(computation->ReplaceInstruction(add, fusion));
  HloModule module(TestName());
  auto computation = module.AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({add},
                                       HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(),
              op::Fusion(op::Reshape(param0), op::Reshape(param1)));

  EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Fusion(param0, param1)));

@@ -172,7 +172,14 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
    return tensorflow::Status::OK();
  };

  for (auto& comp : module->computations()) {
  std::vector<HloComputation*> computations;
  for (auto& computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    computations.push_back(computation.get());
  }
  for (auto& comp : computations) {
    TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
  }

@@ -135,6 +135,9 @@ TuplePointsToAnalysis::Run(const HloModule* module) {
Status TuplePointsToAnalysis::Analyze() {
  points_to_.clear();
  for (auto& computation : module_->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(computation->Accept(this));
    TF_RETURN_IF_ERROR(
        PopulateDefinedBuffersAndAliases(computation->instructions()));

@@ -451,6 +454,9 @@ string TuplePointsToAnalysis::ToString() const {
  string output = tensorflow::strings::Printf(
      "TuplePointsToSet for module %s:\n", module_->name().c_str());
  for (const auto& computation : module_->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    const char* entry =
        computation.get() == module_->entry_computation() ? "entry " : "";
    tensorflow::strings::StrAppend(&output, entry, "computation ",

@@ -35,23 +35,20 @@ py_library(

cuda_py_test(
    name = "csiszar_divergence_test",
    size = "small",
    size = "medium",
    srcs = ["python/kernel_tests/csiszar_divergence_test.py"],
    additional_deps = [
        ":bayesflow_py",
        "//third_party/py/numpy",
        "//tensorflow/contrib/distributions:distributions_py",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:gradients",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform_test",
    ],
)

@@ -84,12 +81,11 @@ cuda_py_test(
        "//third_party/py/numpy",
        "//tensorflow/contrib/distributions:distributions_py",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:variables",
    ],
)

@@ -20,22 +20,24 @@ from __future__ import print_function

import numpy as np

from tensorflow.contrib import distributions as distributions_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy_lib
from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import kullback_leibler as kullback_leibler_lib
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.platform import test

distributions = distributions_lib
layers = layers_lib
entropy = entropy_lib


class NormalNoEntropy(distributions.Normal):  # pylint: disable=no-init
class NormalNoEntropy(normal_lib.Normal):  # pylint: disable=no-init
  """Normal distribution without a `.entropy` method."""

  def entropy(self):

@@ -81,10 +83,10 @@ class ElboRatioTest(test.TestCase):
    n_samples = 5000

    with self.test_session():
      q = distributions.MultivariateNormalDiag(
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))
      p = distributions.MultivariateNormalDiag(
      p = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))

@@ -95,7 +97,7 @@ class ElboRatioTest(test.TestCase):
          n=n_samples,
          form=entropy.ELBOForms.sample,
          seed=42)
      actual_kl = distributions.kl_divergence(q, p)
      actual_kl = kullback_leibler_lib.kl_divergence(q, p)

      # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
      # pass.

@@ -109,10 +111,10 @@ class ElboRatioTest(test.TestCase):

    vector_shape = (2, 3)
    with self.test_session():
      q = distributions.MultivariateNormalDiag(
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))
      p = distributions.MultivariateNormalDiag(
      p = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))

@@ -123,7 +125,7 @@ class ElboRatioTest(test.TestCase):
          n=n_samples,
          form=entropy.ELBOForms.analytic_entropy,
          seed=42)
      actual_kl = distributions.kl_divergence(q, p)
      actual_kl = kullback_leibler_lib.kl_divergence(q, p)

      # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
      # pass.

@@ -135,7 +137,7 @@ class ElboRatioTest(test.TestCase):

    vector_shape = (2, 3)
    with self.test_session():
      q = distributions.MultivariateNormalDiag(
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))

@@ -155,7 +157,7 @@ class EntropyShannonTest(test.TestCase):

  def test_normal_entropy_default_form_uses_exact_entropy(self):
    with self.test_session():
      dist = distributions.Normal(loc=1.11, scale=2.22)
      dist = normal_lib.Normal(loc=1.11, scale=2.22)
      mc_entropy = entropy.entropy_shannon(dist, n=11)
      exact_entropy = dist.entropy()
      self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape())

@@ -163,7 +165,7 @@ class EntropyShannonTest(test.TestCase):

  def test_normal_entropy_analytic_form_uses_exact_entropy(self):
    with self.test_session():
      dist = distributions.Normal(loc=1.11, scale=2.22)
      dist = normal_lib.Normal(loc=1.11, scale=2.22)
      mc_entropy = entropy.entropy_shannon(
          dist, form=entropy.ELBOForms.analytic_entropy)
      exact_entropy = dist.entropy()

@@ -173,7 +175,7 @@ class EntropyShannonTest(test.TestCase):
  def test_normal_entropy_sample_form_gets_approximate_answer(self):
    # Tested by showing we get a good answer that is not exact.
    with self.test_session():
      dist = distributions.Normal(loc=1.11, scale=2.22)
      dist = normal_lib.Normal(loc=1.11, scale=2.22)
      mc_entropy = entropy.entropy_shannon(
          dist, n=1000, form=entropy.ELBOForms.sample, seed=0)
      exact_entropy = dist.entropy()

@@ -193,7 +195,7 @@ class EntropyShannonTest(test.TestCase):
      # NormalNoEntropy is like a Normal, but does not have .entropy method, so
      # we are forced to fall back on sample entropy.
      dist_no_entropy = NormalNoEntropy(loc=1.11, scale=2.22)
      dist_yes_entropy = distributions.Normal(loc=1.11, scale=2.22)
      dist_yes_entropy = normal_lib.Normal(loc=1.11, scale=2.22)

      mc_entropy = entropy.entropy_shannon(
          dist_no_entropy, n=1000, form=entropy.ELBOForms.sample, seed=0)

@@ -222,15 +224,16 @@ class RenyiRatioTest(test.TestCase):
    mu_true = np.array([1.0, -1.0], dtype=np.float64)
    chol_true = np.array([[2.0, 0.0], [0.5, 1.0]], dtype=np.float64)
    with self.test_session() as sess:
      target = distributions.MultivariateNormalTriL(mu_true, chol_true)
      target = mvn_tril_lib.MultivariateNormalTriL(mu_true, chol_true)

      # Set up q distribution by defining mean/covariance as Variables
      mu = variables.Variable(
          np.zeros(mu_true.shape), dtype=mu_true.dtype, name='mu')
      mat = variables.Variable(
          np.zeros(chol_true.shape), dtype=chol_true.dtype, name='mat')
      chol = distributions.matrix_diag_transform(mat, transform=nn_ops.softplus)
      q = distributions.MultivariateNormalTriL(mu, chol)
      chol = distribution_util.matrix_diag_transform(
          mat, transform=nn_ops.softplus)
      q = mvn_tril_lib.MultivariateNormalTriL(mu, chol)
      for alpha in [0.25, 0.75]:

        negative_renyi_divergence = entropy.renyi_ratio(

@@ -262,7 +265,7 @@ class RenyiRatioTest(test.TestCase):
    n = 1000
    vector_shape = (2, 3)
    with self.test_session():
      q = distributions.MultivariateNormalDiag(
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))
      for alpha in [0.25, 0.75]:

@@ -36,27 +36,48 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        ":bijectors_py",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/learn",
        "//tensorflow/contrib/linalg:linalg_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:data_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/ops/distributions",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

cuda_py_test(
    name = "estimator_test",
    size = "small",
    srcs = ["python/kernel_tests/estimator_test.py"],
    additional_deps = [
        ":distributions_py",
        "//third_party/py/numpy",
        "@six_archive//:six",
        "//tensorflow/contrib/learn",
        "//tensorflow/contrib/learn:head_test",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:session",
    ],
)

cuda_py_test(
    name = "distribution_test",
    size = "small",
@ -30,6 +30,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_transformed_distrib
from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
from tensorflow.contrib.distributions.python.ops.estimator import *
from tensorflow.contrib.distributions.python.ops.geometric import *
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
from tensorflow.contrib.distributions.python.ops.logistic import *

@ -147,6 +148,7 @@ _allowed_symbols = [
    'percentile',
    'assign_exponential_moving_mean_variance',
    'exponential_moving_mean_variance',
    'estimator_head_distribution_regression',
]

remove_undocumented(__name__, _allowed_symbols)
@ -0,0 +1,114 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for estimator.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.contrib.distributions.python.ops import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators.head_test import _assert_metrics
from tensorflow.contrib.learn.python.learn.estimators.head_test import _assert_no_variables
from tensorflow.contrib.learn.python.learn.estimators.head_test import _assert_summary_tags
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test


class EstimatorHeadDistributionRegressionTest(test.TestCase):

  def _assert_output_alternatives(self, model_fn_ops):
    self.assertEquals({
        None: constants.ProblemType.LINEAR_REGRESSION
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })

  def testNormalLocScaleLogits(self):
    # We will bias logits[..., 1] so that: logits[..., 1]=0 implies scale=1.
    scale_bias = np.log(np.expm1(1.))

    def softplus(x):
      return np.log1p(np.exp(x))

    def actual_loss(logits, labels):
      mu = actual_mean(logits)
      sigma = actual_stddev(logits)
      labels = np.squeeze(labels, -1)
      z = (labels - mu) / sigma
      loss = 0.5 * (z**2. + np.log(2. * np.pi)) + np.log(sigma)
      return loss.mean()

    def actual_mean(logits):
      return logits[..., 0]

    def actual_stddev(logits):
      return softplus(logits[..., 1] + scale_bias)

    def make_distribution_fn(logits):
      return normal_lib.Normal(
          loc=logits[..., 0],
          scale=nn_ops.softplus(logits[..., 1] + scale_bias))

    head = estimator_lib.estimator_head_distribution_regression(
        make_distribution_fn,
        logits_dimension=2)
    labels = np.float32([[-1.],
                         [0.],
                         [1.]])
    logits = np.float32([[0., -1],
                         [1, 0.5],
                         [-1, 1]])
    with ops.Graph().as_default(), session.Session():
      # Convert to tensor so we can index into head.distributions.
      tflogits = ops.convert_to_tensor(logits, name="logits")
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=tflogits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_summary_tags(self, ["loss"])
      _assert_no_variables(self)
      loss = actual_loss(logits, labels)
      _assert_metrics(self, loss, {"loss": loss}, model_fn_ops)

      # Now we verify the underlying distribution was correctly constructed.
      expected_mean = logits[..., 0]
      self.assertAllClose(
          expected_mean,
          head.distribution(tflogits).mean().eval(),
          rtol=1e-6, atol=0.)

      expected_stddev = softplus(logits[..., 1] + scale_bias)
      self.assertAllClose(
          expected_stddev,
          head.distribution(tflogits).stddev().eval(),
          rtol=1e-6, atol=0.)
      # Should have created only one distribution.
      self.assertEqual(1, len(head.distributions))


if __name__ == "__main__":
  test.main()
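The `actual_loss` helper in the test above is just the Normal negative log-likelihood written out by hand. A quick standalone NumPy check of that identity (the values here are arbitrary):

    import numpy as np

    mu, sigma, y = 0.3, 1.7, -0.9
    z = (y - mu) / sigma
    # Expanded form, exactly as in actual_loss above.
    nll = 0.5 * (z**2 + np.log(2. * np.pi)) + np.log(sigma)
    # Same value computed from the density directly.
    pdf = np.exp(-0.5 * z**2) / (sigma * np.sqrt(2. * np.pi))
    assert np.isclose(nll, -np.log(pdf))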
tensorflow/contrib/distributions/python/ops/estimator.py (new file, 185 lines)
@ -0,0 +1,185 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to bridge `Distribution`s and `tf.contrib.learn.estimator` APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators.head import _compute_weighted_loss
from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHead
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops


__all__ = [
    "estimator_head_distribution_regression",
]


def estimator_head_distribution_regression(make_distribution_fn,
                                           label_dimension=1,
                                           logits_dimension=None,
                                           label_name=None,
                                           weight_column_name=None,
                                           enable_centered_bias=False,
                                           head_name=None):
  """Creates a `Head` for regression under a generic distribution.

  Args:
    make_distribution_fn: Python `callable` which returns a `tf.Distribution`
      instance created using only logits.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    logits_dimension: Number of logits per example. This is the size of the last
      dimension of the logits `Tensor` (typically, this has shape
      `[batch_size, logits_dimension]`).
      Default value: `label_dimension`.
    label_name: Python `str`, name of the key in label `dict`. Can be `None` if
      label is a `Tensor` (single headed models).
    weight_column_name: Python `str` defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    enable_centered_bias: Python `bool`. If `True`, estimator will learn a
      centered bias variable for each class. Rest of the model structure learns
      the residual after centered bias.
    head_name: Python `str`, name of the head. Predictions, summary and metrics
      keys are suffixed by `"/" + head_name` and the default variable scope is
      `head_name`.

  Returns:
    An instance of `Head` for generic regression.
  """
  return _DistributionRegressionHead(
      make_distribution_fn=make_distribution_fn,
      label_dimension=label_dimension,
      logits_dimension=logits_dimension,
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name)
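A minimal usage sketch for the function above, mirroring the new test; the two-logit softplus parameterization is just one possible choice of `make_distribution_fn`:

    from tensorflow.python.ops import nn_ops
    from tensorflow.python.ops.distributions import normal as normal_lib

    def make_normal(logits):
      # logits[..., 0] parameterizes the mean, logits[..., 1] the stddev.
      return normal_lib.Normal(loc=logits[..., 0],
                               scale=nn_ops.softplus(logits[..., 1]))

    head = estimator_head_distribution_regression(
        make_normal, logits_dimension=2)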
class _DistributionRegressionHead(_RegressionHead):
  """Creates a _RegressionHead instance from an arbitrary `Distribution`."""

  def __init__(self,
               make_distribution_fn,
               label_dimension,
               logits_dimension=None,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None):
    """`Head` for regression.

    Args:
      make_distribution_fn: Python `callable` which returns a `tf.Distribution`
        instance created using only logits.
      label_dimension: Number of regression labels per example. This is the
        size of the last dimension of the labels `Tensor` (typically, this has
        shape `[batch_size, label_dimension]`).
      logits_dimension: Number of logits per example. This is the size of the
        last dimension of the logits `Tensor` (typically, this has shape
        `[batch_size, logits_dimension]`).
        Default value: `label_dimension`.
      label_name: Python `str`, name of the key in label `dict`. Can be `None`
        if label is a tensor (single headed models).
      weight_column_name: Python `str` defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      enable_centered_bias: Python `bool`. If `True`, estimator will learn a
        centered bias variable for each class. Rest of the model structure
        learns the residual after centered bias.
      head_name: Python `str`, name of the head. Predictions, summary and
        metrics keys are suffixed by `"/" + head_name` and the default variable
        scope is `head_name`.

    Raises:
      TypeError: if `make_distribution_fn` is not `callable`.
    """
    if not callable(make_distribution_fn):
      raise TypeError("`make_distribution_fn` must be a callable function.")

    self._distributions = {}
    self._make_distribution_fn = make_distribution_fn

    def static_value(x):
      """Returns the static value of a `Tensor` or `None`."""
      return tensor_util.constant_value(ops.convert_to_tensor(x))

    def concat_vectors(*args):
      """Concatenates input vectors, statically if possible."""
      args_ = [static_value(x) for x in args]
      if any(vec is None for vec in args_):
        return array_ops.concat(args, axis=0)
      return [val for vec in args_ for val in vec]
    def loss_fn(labels, logits, weights=None):
      """Returns the loss of using `logits` to predict `labels`."""
      d = self.distribution(logits)
      labels_batch_shape = labels.shape.with_rank_at_least(1)[:-1]
      labels_batch_shape = (
          labels_batch_shape.as_list() if labels_batch_shape.is_fully_defined()
          else array_ops.shape(labels)[:-1])
      labels = array_ops.reshape(
          labels,
          shape=concat_vectors(labels_batch_shape, d.event_shape_tensor()))
      return _compute_weighted_loss(
          loss_unweighted=-d.log_prob(labels),
          weight=weights)

    def link_fn(logits):
      """Returns the inverse link function at `logits`."""
      # Note: What the API calls a "link function" is really the inverse-link
      # function, i.e., the "mean".
      d = self.distribution(logits)
      return d.mean()

    super(_DistributionRegressionHead, self).__init__(
        label_dimension=label_dimension,
        loss_fn=loss_fn,
        link_fn=link_fn,
        logits_dimension=logits_dimension,
        label_name=label_name,
        weight_column_name=weight_column_name,
        enable_centered_bias=enable_centered_bias,
        head_name=head_name)

  @property
  def distributions(self):
    """Returns all distributions created by `DistributionRegressionHead`."""
    return self._distributions

  def distribution(self, logits, name=None):
    """Retrieves a distribution instance, parameterized by `logits`.

    Args:
      logits: `float`-like `Tensor` representing the parameters of the
        underlying distribution.
      name: The Python `str` name to give to this op.
        Default value: "distribution".

    Returns:
      distribution: `tf.Distribution` instance parameterized by `logits`.
    """
    with ops.name_scope(name, "distribution", [logits]):
      d = self._distributions.get(logits, None)
      if d is None:
        d = self._make_distribution_fn(logits)
        self._distributions[logits] = d
      return d
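Two implementation details above are worth spelling out. `distribution()` memoizes per logits tensor, which is why the test asserts that only one distribution instance is created even though both `loss_fn` and `link_fn` call it. And `concat_vectors` prefers a static Python-list result, falling back to `array_ops.concat` only when some shape is dynamic; a pure-Python sketch of the static path:

    def concat_vectors_static(*args):
      # If every input vector is statically known, the concatenation is an
      # ordinary list and costs nothing at graph run time; None stands in
      # for a dynamically shaped input, which would force the concat op.
      if any(vec is None for vec in args):
        return None  # TF version: array_ops.concat(args, axis=0)
      return [val for vec in args for val in vec]

    # Labels batch shape [3] + scalar event shape [] -> reshape target [3].
    assert concat_vectors_static([3], []) == [3]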
@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops

@ -189,7 +189,7 @@ class MultivariateNormalLinearOperator(
        distribution=normal.Normal(
            loc=array_ops.zeros([], dtype=scale.dtype),
            scale=array_ops.ones([], dtype=scale.dtype)),
        bijector=bijectors.AffineLinearOperator(
        bijector=AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args),
        batch_shape=batch_shape,
        event_shape=event_shape,
@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

@ -31,13 +32,16 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import convolutional as convolutional_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.layers import pooling as pooling_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops

@ -1281,7 +1285,7 @@ def convolution3d_transpose(
    trainable=True,
    scope=None):
  """Adds a convolution3d_transpose with an optional batch normalization layer.

  The function creates a variable called `weights`, representing the
  kernel, that is convolved with the input. If `batch_norm_params` is `None`, a
  second variable called 'biases' is added to the result of the operation.

@ -1808,6 +1812,300 @@ def layer_norm(inputs,
      outputs)

class GDN(base.Layer):
  """Generalized divisive normalization layer.

  Based on the papers:

    "Density Modeling of Images using a Generalized Normalization
    Transformation"
    Johannes Ballé, Valero Laparra, Eero P. Simoncelli
    https://arxiv.org/abs/1511.06281

    "End-to-end Optimized Image Compression"
    Johannes Ballé, Valero Laparra, Eero P. Simoncelli
    https://arxiv.org/abs/1611.01704

  Implements an activation function that is essentially a multivariate
  generalization of a particular sigmoid-type function:

    y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))

  where i and j run over channels. This implementation never sums across spatial
  dimensions. It is similar to local response normalization, but more powerful,
  as beta and gamma are trainable parameters.

  Arguments:
    inverse: If False (default), compute GDN response. If True, compute IGDN
      response (one step of fixed point iteration to invert GDN; the division
      is replaced by multiplication).
    beta_min: Lower bound for beta, to prevent numerical error from causing
      square root of zero or negative values.
    gamma_init: The gamma matrix will be initialized as the identity matrix
      multiplied with this value. If set to zero, the layer is effectively
      initialized to the identity operation, since beta is initialized as one.
      A good default setting is somewhere between 0 and 0.5.
    reparam_offset: Offset added to the reparameterization of beta and gamma.
      The reparameterization of beta and gamma as their square roots lets the
      training slow down when their values are close to zero, which is desirable
      as small values in the denominator can lead to a situation where gradient
      noise on beta/gamma leads to extreme amounts of noise in the GDN
      activations. However, without the offset, we would get zero gradients if
      any elements of beta or gamma were exactly zero, and thus the training
      could get stuck. To prevent this, we add this small constant. The default
      value was empirically determined as a good starting point. Making it
      bigger potentially leads to more gradient noise on the activations; making
      it too small may lead to numerical precision issues.
    data_format: Format of input tensor. Currently supports 'channels_first' and
      'channels_last'.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such cases.
    reuse: Boolean, whether to reuse the weights of a previous layer
      by the same name.

  Properties:
    inverse: Boolean, whether IGDN is computed (True) or GDN (False).
    data_format: Format of input tensor. Currently supports 'channels_first' and
      'channels_last'.
    beta: The beta parameter as defined above (1D TensorFlow tensor).
    gamma: The gamma parameter as defined above (2D TensorFlow tensor).
  """
  def __init__(self,
               inverse=False,
               beta_min=1e-6,
               gamma_init=.1,
               reparam_offset=2 ** -18,
               data_format='channels_last',
               trainable=True,
               name=None,
               **kwargs):
    super(GDN, self).__init__(trainable=trainable, name=name, **kwargs)
    self.inverse = inverse
    self._beta_min = beta_min
    self._gamma_init = gamma_init
    self._reparam_offset = reparam_offset
    self.data_format = data_format
    self._channel_axis()  # trigger ValueError early
    self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5)

  def _channel_axis(self):
    try:
      return {'channels_first': 1, 'channels_last': -1}[self.data_format]
    except KeyError:
      raise ValueError('Unsupported `data_format` for GDN layer: {}.'.format(
          self.data_format))

  @staticmethod
  def _lower_bound(inputs, bound, name=None):
    """Same as tf.maximum, but with helpful gradient for inputs < bound.

    The gradient is overwritten so that it is passed through if the input is not
    hitting the bound. If it is, only gradients that push `inputs` higher than
    the bound are passed through. No gradients are passed through to the bound.

    Args:
      inputs: input tensor
      bound: lower bound for the input tensor
      name: name for this op

    Returns:
      tf.maximum(inputs, bound)
    """
    with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
      inputs = ops.convert_to_tensor(inputs, name='inputs')
      bound = ops.convert_to_tensor(bound, name='bound')
      with ops.get_default_graph().gradient_override_map(
          {'Maximum': 'GDNLowerBound'}):
        return math_ops.maximum(inputs, bound, name=scope)
  @staticmethod
  @ops.RegisterGradient('GDNLowerBound')
  def _lower_bound_grad(op, grad):
    """Gradient for `_lower_bound`.

    Args:
      op: the tensorflow op for which to calculate a gradient
      grad: gradient with respect to the output of the op

    Returns:
      gradients with respect to the inputs of the op
    """
    inputs = op.inputs[0]
    bound = op.inputs[1]
    pass_through_if = math_ops.logical_or(inputs >= bound, grad < 0)
    return [math_ops.cast(pass_through_if, grad.dtype) * grad, None]
  def build(self, input_shape):
    channel_axis = self._channel_axis()
    input_shape = tensor_shape.TensorShape(input_shape)
    num_channels = input_shape[channel_axis].value
    if num_channels is None:
      raise ValueError('The channel dimension of the inputs to `GDN` '
                       'must be defined.')
    self._input_rank = input_shape.ndims
    self.input_spec = base.InputSpec(ndim=input_shape.ndims,
                                     axes={channel_axis: num_channels})

    pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
    beta_bound = array_ops.constant(
        (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
    gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)

    def beta_initializer(shape, dtype=None, partition_info=None):
      del partition_info  # unused
      return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal)

    def gamma_initializer(shape, dtype=None, partition_info=None):
      del partition_info  # unused
      assert len(shape) == 2
      assert shape[0] == shape[1]
      eye = linalg_ops.eye(shape[0], dtype=dtype)
      return math_ops.sqrt(self._gamma_init * eye + pedestal)

    beta = self.add_variable('reparam_beta',
                             shape=[num_channels],
                             initializer=beta_initializer,
                             dtype=self.dtype,
                             trainable=True)
    beta = self._lower_bound(beta, beta_bound)
    self.beta = math_ops.square(beta) - pedestal

    gamma = self.add_variable('reparam_gamma',
                              shape=[num_channels, num_channels],
                              initializer=gamma_initializer,
                              dtype=self.dtype,
                              trainable=True)
    gamma = self._lower_bound(gamma, gamma_bound)
    self.gamma = math_ops.square(gamma) - pedestal

    self.built = True

  def call(self, inputs):
    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
    ndim = self._input_rank

    shape = self.gamma.get_shape().as_list()
    gamma = array_ops.reshape(self.gamma, (ndim - 2) * [1] + shape)

    # Compute normalization pool.
    if self.data_format == 'channels_first':
      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
                                 data_format='NC' + 'DHW'[-(ndim - 2):])
      if ndim == 3:
        norm_pool = array_ops.expand_dims(norm_pool, 2)
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
        norm_pool = array_ops.squeeze(norm_pool, [2])
      elif ndim == 5:
        shape = array_ops.shape(norm_pool)
        # Concatenate shape components; `shape[:3] + [-1]` would add
        # elementwise on a shape tensor rather than concatenate.
        norm_pool = array_ops.reshape(
            norm_pool, array_ops.concat([shape[:3], [-1]], axis=0))
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
        norm_pool = array_ops.reshape(norm_pool, shape)
      else:  # ndim == 4
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
    else:  # channels_last
      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID')
      norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NHWC')
    norm_pool = math_ops.sqrt(norm_pool)

    if self.inverse:
      outputs = inputs * norm_pool
    else:
      outputs = inputs / norm_pool
    outputs.set_shape(inputs.get_shape())
    return outputs

  def _compute_output_shape(self, input_shape):
    channel_axis = self._channel_axis()
    input_shape = tensor_shape.TensorShape(input_shape)
    if not 3 <= input_shape.ndims <= 5:
      raise ValueError('`input_shape` must be of rank 3 to 5, inclusive.')
    if input_shape[channel_axis].value is None:
      raise ValueError(
          'The channel dimension of `input_shape` must be defined.')
    return input_shape


def gdn(inputs,
        inverse=False,
        beta_min=1e-6,
        gamma_init=.1,
        reparam_offset=2 ** -18,
        data_format='channels_last',
        trainable=True,
        name=None,
        reuse=None):
  """Functional interface for GDN layer.

  Based on the papers:

    "Density Modeling of Images using a Generalized Normalization
    Transformation"
    Johannes Ballé, Valero Laparra, Eero P. Simoncelli
    https://arxiv.org/abs/1511.06281

    "End-to-end Optimized Image Compression"
    Johannes Ballé, Valero Laparra, Eero P. Simoncelli
    https://arxiv.org/abs/1611.01704

  Implements an activation function that is essentially a multivariate
  generalization of a particular sigmoid-type function:

    y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))

  where i and j run over channels. This implementation never sums across spatial
  dimensions. It is similar to local response normalization, but more powerful,
  as beta and gamma are trainable parameters.

  Arguments:
    inputs: Tensor input.
    inverse: If False (default), compute GDN response. If True, compute IGDN
      response (one step of fixed point iteration to invert GDN; the division
      is replaced by multiplication).
    beta_min: Lower bound for beta, to prevent numerical error from causing
      square root of zero or negative values.
    gamma_init: The gamma matrix will be initialized as the identity matrix
      multiplied with this value. If set to zero, the layer is effectively
      initialized to the identity operation, since beta is initialized as one.
      A good default setting is somewhere between 0 and 0.5.
    reparam_offset: Offset added to the reparameterization of beta and gamma.
      The reparameterization of beta and gamma as their square roots lets the
      training slow down when their values are close to zero, which is desirable
      as small values in the denominator can lead to a situation where gradient
      noise on beta/gamma leads to extreme amounts of noise in the GDN
      activations. However, without the offset, we would get zero gradients if
      any elements of beta or gamma were exactly zero, and thus the training
      could get stuck. To prevent this, we add this small constant. The default
      value was empirically determined as a good starting point. Making it
      bigger potentially leads to more gradient noise on the activations; making
      it too small may lead to numerical precision issues.
    data_format: Format of input tensor. Currently supports 'channels_first' and
      'channels_last'.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such cases.
    reuse: Boolean, whether to reuse the weights of a previous layer
      by the same name.

  Returns:
    Output tensor.
  """
  layer = GDN(inverse=inverse,
              beta_min=beta_min,
              gamma_init=gamma_init,
              reparam_offset=reparam_offset,
              data_format=data_format,
              trainable=trainable,
              name=name,
              dtype=inputs.dtype.base_dtype,
              _scope=name,
              _reuse=reuse)
  return layer.apply(inputs)


@add_arg_scope
def max_pool2d(inputs,
               kernel_size,
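To make the GDN arithmetic above concrete: with the default initialization (beta = 1, gamma = 0.1 * I), the normalization pool degenerates to a per-channel expression, which is exactly the closed form the tests below compare against. A NumPy reference for the channels-last case (a sketch for illustration, not the layer's implementation):

    import numpy as np

    def gdn_reference(x, beta, gamma, inverse=False):
      # x: [height, width, channels]; beta: [C]; gamma: [C, C].
      # norm[i] = beta[i] + sum_j(gamma[j, i] * x[j]**2) at each position.
      norm = np.sqrt(beta + np.einsum('hwj,ji->hwi', x**2, gamma))
      return x * norm if inverse else x / norm

    x = np.random.uniform(size=(2, 3, 4))
    y = gdn_reference(x, beta=np.ones(4), gamma=0.1 * np.eye(4))
    assert np.allclose(y, x / np.sqrt(1 + 0.1 * x**2))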
@ -2772,6 +2772,56 @@ class LayerNormTest(test.TestCase):
    self.doOutputTest((1, 100, 100, 1))


class GDNTest(test.TestCase):

  def _runGDN(self, x, shape, inverse, data_format):
    inputs = array_ops.placeholder(dtypes.float32, shape)
    outputs = _layers.gdn(inputs, inverse=inverse, data_format=data_format)
    with self.test_session() as sess:
      variables_lib.global_variables_initializer().run()
      y, = sess.run([outputs], {inputs: x})
    return y

  def testInvalidDataFormat(self):
    x = np.random.uniform(size=(1, 2, 3, 4))
    with self.assertRaises(ValueError):
      self._runGDN(x, x.shape, False, 'NHWC')

  def testUnknownDim(self):
    x = np.random.uniform(size=(1, 2, 3, 4))
    with self.assertRaises(ValueError):
      self._runGDN(x, 4 * [None], False, 'channels_last')

  def testChannelsLast(self):
    for ndim in [3, 4, 5]:
      x = np.random.uniform(size=(1, 2, 3, 4)[:ndim])
      y = self._runGDN(x, x.shape, False, 'channels_last')
      self.assertEqual(x.shape, y.shape)
      self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)

  def testChannelsFirst(self):
    # `bias_add` doesn't support NCHW on CPU.
    if test.is_gpu_available(cuda_only=True):
      for ndim in [3, 4, 5]:
        x = np.random.uniform(size=(4, 3, 2, 1)[:ndim])
        y = self._runGDN(x, x.shape, False, 'channels_first')
        self.assertEqual(x.shape, y.shape)
        self.assertAllClose(
            y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)

  def testWrongDims(self):
    for ndim in [1, 2, 6]:
      x = np.random.uniform(size=(1, 2, 3, 4, 3, 2)[:ndim])
      with self.assertRaises(ValueError):
        self._runGDN(x, x.shape, False, 'channels_last')

  def testIGDN(self):
    x = np.random.uniform(size=(1, 2, 3, 4))
    y = self._runGDN(x, x.shape, True, 'channels_last')
    self.assertEqual(x.shape, y.shape)
    self.assertAllClose(y, x * np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)


class MaxPool2DTest(test.TestCase):

  def testInvalidDataFormat(self):
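The custom gradient registered for GDN's `_lower_bound` above is worth restating: the bound itself never receives gradient, and a clipped parameter still receives gradient whenever the update would push it back above the bound. The rule in NumPy (backward pass only, as a sketch):

    import numpy as np

    def lower_bound_grad(x, bound, grad):
      # Pass the incoming gradient through when x is not clipped (x >= bound),
      # or when the update would raise x above the bound (grad < 0 under
      # gradient descent).
      pass_through = (x >= bound) | (grad < 0)
      return np.where(pass_through, grad, 0.)

    g = lower_bound_grad(np.array([0.5, 0.5, 2.0]), bound=1.0,
                         grad=np.array([1., -1., 1.]))
    # The clipped entry whose gradient points further downward gets nothing.
    assert (g == np.array([0., -1., 1.])).all()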
@ -665,6 +665,7 @@ class _RegressionHead(_SingleHead):
               label_dimension,
               loss_fn,
               link_fn,
               logits_dimension=None,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,

@ -677,6 +678,10 @@ class _RegressionHead(_SingleHead):
        shape `[batch_size, label_dimension]`).
      loss_fn: Loss function, takes logits and labels and returns loss.
      link_fn: Link function, takes a logits tensor and returns the output.
      logits_dimension: Number of logits per example. This is the
        size of the last dimension of the logits `Tensor` (typically, this has
        shape `[batch_size, logits_dimension]`).
        Default value: `label_dimension`.
      label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing

@ -691,7 +696,8 @@ class _RegressionHead(_SingleHead):
    """
    super(_RegressionHead, self).__init__(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        logits_dimension=label_dimension,
        logits_dimension=(logits_dimension if logits_dimension is not None
                          else label_dimension),
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)
@ -103,7 +103,7 @@ def stft(signals, frame_length, frame_step, fft_length=None,
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length,
                 fft_length=None,
                 window_fn=functools.partial(window_ops.hann_window,
                                             periodic=True),
                 name=None):

@ -118,7 +118,8 @@ def inverse_stft(stfts,
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`.
      `stfts`. If not provided, uses the smallest power of 2 enclosing
      `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.

@ -130,7 +131,8 @@ def inverse_stft(stfts,

  Raises:
    ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
      `frame_step` is not scalar, or `fft_length` is not scalar.
      `frame_step` is not scalar, or `fft_length` is not scalar, or
      `frame_length` is greater than `fft_length`.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """

@ -141,8 +143,21 @@ def inverse_stft(stfts,
  frame_length.shape.assert_has_rank(0)
  frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
  frame_step.shape.assert_has_rank(0)
  fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
  fft_length.shape.assert_has_rank(0)
  if fft_length is None:
    fft_length = _enclosing_power_of_two(frame_length)
  else:
    fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
    fft_length.shape.assert_has_rank(0)

  frame_length_static = tensor_util.constant_value(frame_length)
  fft_length_static = tensor_util.constant_value(fft_length)
  if (frame_length_static is not None and fft_length_static is not None and
      frame_length_static > fft_length_static):
    raise ValueError('frame_length (%d) may not be larger than '
                     'fft_length (%d)' % (frame_length_static,
                                          fft_length_static))

  real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length]

  # Optionally window and overlap-add the inner 2 dimensions of real_frames
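The new `fft_length=None` default delegates to `_enclosing_power_of_two` in the same file. For a static `frame_length`, the computation amounts to the following sketch (the TF helper also accepts tensors; this mirrors only the static case):

    def enclosing_power_of_two(n):
      # Smallest power of 2 that is >= n, e.g. a 400-sample frame -> 512.
      # Integer bit tricks avoid float-precision issues with log2.
      return 2 ** max(0, (n - 1).bit_length())

    assert enclosing_power_of_two(400) == 512
    assert enclosing_power_of_two(512) == 512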
@ -98,14 +98,16 @@ class DirectSession : public Session {
  ::tensorflow::Status ListDevices(
      std::vector<DeviceAttributes>* response) override;
  ::tensorflow::Status Close() override;
  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
    *output = device_mgr_.get();
    return ::tensorflow::Status::OK();
  }

  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
    cost_model_manager_.ExportCostModels(cost_models);
  }

 private:
  typedef DirectSession ME;

  // We create one executor and its dependent library runtime for
  // every partition.
  struct PerPartitionExecutorsAndLib {
@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"

@ -1248,6 +1249,16 @@ TEST(DirectSessionTest, TestDirectSessionReset) {
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}

TEST(DirectSessionTest, LocalDeviceManager) {
  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));

  const DeviceMgr* mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&mgr));
  ASSERT_TRUE(mgr != nullptr);
  EXPECT_GT(mgr->ListDevices().size(), 0);
}

// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
void FeedFetchBenchmarkHelper(int num_feeds, int iters) {
@ -260,6 +260,8 @@ class CallOp : public AsyncOpKernel {
                     done);
    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();

@ -545,23 +547,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
  Executor::Args exec_args;
  // Inherit the step_id from the caller.
  exec_args.step_id = opts.step_id;
  exec_args.step_container = opts.step_container;

  exec_args.rendezvous = opts.rendezvous;
  exec_args.stats_collector = opts.stats_collector;
  exec_args.call_frame = frame;
  exec_args.cancellation_manager = opts.cancellation_manager;
  exec_args.step_container = opts.step_container;
  exec_args.runner = *opts.runner;
  // TODO(zhifengc): we can avoid creating rendez here if we know
  // there is no send/recv nodes in the graph.
  auto* rendez = new IntraProcessRendezvous(device_mgr_);
  exec_args.rendezvous = rendez;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [item, frame, rets, rendez, done](const Status& status) {
      [item, frame, rets, done](const Status& status) {
        item->Unref();
        rendez->Unref();
        Status s = status;
        if (s.ok()) {
          s = frame->GetRetvals(rets);
@ -104,10 +104,21 @@ Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
  });
}

static Node* Send(Graph* g, const string& device_name, bool host,
                  const Edge* edge) {
  const string tensor_name =
      strings::StrCat("edge_", edge->id(), "_", edge->src()->name());
// Given an Edge whose two endpoints have different memory types and
// for which we are going to insert a pair of HostSend/Recv or Send/HostRecv
// nodes, GetTensorName() returns a unique string that we can use as part of
// the rendezvous key. The returned string is guaranteed to be unique
// within this process. That is sufficient because EnsureMemoryTypes
// is only used on a TensorFlow graph that is going to be executed in
// a single tf device (hence within a single process).
static string GetTensorName(const Edge* edge) {
  static std::atomic<int64> counter(0);
  return strings::StrCat("memtype_", counter.fetch_add(1), "_",
                         edge->src()->name());
}

static Node* Send(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
                  .Input(edge->src(), edge->src_output())

@ -115,14 +126,13 @@ static Node* Send(Graph* g, const string& device_name, bool host,
                  .Attr("send_device", device_name)
                  .Attr("send_device_incarnation", 0)  // Do not care.
                  .Attr("recv_device", device_name)
                  .Attr("_hostmem_sendrecv", true)
                  .Finalize(g, &ret));
  return ret;
}

static Node* Recv(Graph* g, const string& device_name, bool host,
                  const Edge* edge) {
  const string tensor_name =
      strings::StrCat("edge_", edge->id(), "_", edge->src()->name());
static Node* Recv(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")

@ -131,6 +141,7 @@ static Node* Recv(Graph* g, const string& device_name, bool host,
          .Attr("send_device", device_name)
          .Attr("send_device_incarnation", 0)
          .Attr("recv_device", device_name)
          .Attr("_hostmem_sendrecv", true)
          .Finalize(g, &ret));
  return ret;
}

@ -171,8 +182,10 @@ Status EnsureMemoryTypes(const DeviceType& device_type,
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        Node* send = Send(g, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, device_name, (item.dm == HOST_MEMORY), e);
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache if no ref is involved.
          recv_nodes[key] = recv;
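The point of `GetTensorName` above is that edge ids are only unique within one graph, while the rendezvous key must be unique across every graph executed in the process; an atomic counter gives that guarantee. The naming scheme, restated as a Python sketch of the C++:

    import itertools

    _counter = itertools.count()

    def get_tensor_name(src_node_name):
      # Process-wide unique, unlike edge ids, which can collide across graphs.
      return 'memtype_%d_%s' % (next(_counter), src_node_name)

    assert get_tensor_name('a') != get_tensor_name('a')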
@ -37,6 +37,7 @@ class CancellationManager;
class GraphDef;
class OpKernel;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollector;
class Node;

@ -398,11 +399,10 @@ class FunctionLibraryRuntime {
  //
  // Does not take ownership of "rets".
  struct Options {
    CancellationManager* cancellation_manager = nullptr;
    // The id of the step that is calling this function.
    int64 step_id = 0;

    // Per-step container.
    Rendezvous* rendezvous = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    ScopedStepContainer* step_container = nullptr;
    StepStatsCollector* stats_collector = nullptr;

@ -896,6 +896,36 @@ Status AddControlEdges(const PartitionOptions& opts,
  return Status::OK();
}

// If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
// if possible.
void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
  StringPiece op(ndef->op());
  if (op != "_Send" && op != "_Recv") {
    // Not related to send/recv.
    return;
  }
  string send_device;
  if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
    // No known send_device. The runtime will detect it later.
    return;
  }
  int64 incarnation = opts.get_incarnation(send_device);
  AddNodeAttr("send_device_incarnation", incarnation, ndef);
}

// Sets attribute send_device_incarnation of all Send/Recv nodes in
// 'gdef', if possible.
void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
  for (NodeDef& ndef : *gdef->mutable_node()) {
    SetIncarnation(opts, &ndef);
  }
  for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
    for (NodeDef& ndef : *fdef.mutable_node_def()) {
      SetIncarnation(opts, &ndef);
    }
  }
}

Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;

@ -1130,10 +1160,15 @@ Status Partition(const PartitionOptions& opts, Graph* g,
    }
  }

  // Set versions and function library
  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    it.second.mutable_versions()->CopyFrom(g->versions());
    *it.second.mutable_library() = g->flib_def().ToProto();
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    *gdef->mutable_library() = g->flib_def().ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
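The traversal performed by the two `SetIncarnation` overloads above covers both top-level nodes and every function body in the library. A hedged Python sketch over dict-shaped NodeDefs (the field names mimic the proto; the dict layout is illustrative only):

    def set_incarnation(ndef, get_incarnation):
      # Only Send/Recv nodes carry a send_device_incarnation attribute.
      if ndef.get('op') not in ('_Send', '_Recv'):
        return
      send_device = ndef.get('attr', {}).get('send_device')
      if send_device is None:
        return  # Unknown sender; the runtime detects it later.
      ndef['attr']['send_device_incarnation'] = get_incarnation(send_device)

    def set_graph_incarnations(gdef, get_incarnation):
      for ndef in gdef['node']:
        set_incarnation(ndef, get_incarnation)
      for fdef in gdef.get('library', {}).get('function', []):
        for ndef in fdef['node_def']:
          set_incarnation(ndef, get_incarnation)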
@ -114,7 +114,9 @@ cc_test(
    deps = [
        ":single_machine",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:resource_variable_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
@ -73,8 +73,6 @@ SingleMachine::~SingleMachine() {
  // when we delete the session.
  thread_pool_.reset();

  Reset(options_, {}).IgnoreError();

  CHECK(already_created);
  already_created = false;
}

@ -277,11 +275,9 @@ Status SingleMachine::ResetSession() {
    // Make sure the session is properly closed
    TF_RETURN_IF_ERROR(Shutdown());

    // We need to Reset the session to ensure that all the variables are
    // deleted. But first we need to delete the session since Reset()
    // deletes some of the containers referenced by the session.
    // Destroying the object deletes all its variables as well. This is only
    // true for DirectSession.
    session_.reset();
    TF_RETURN_IF_ERROR(Reset(options_, {}));
  }

  LOG(INFO) << "Starting new session";
@ -15,7 +15,10 @@ limitations under the License.

#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

@ -24,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"

namespace tensorflow {
namespace grappler {

@ -349,6 +353,7 @@ TEST_F(SingleMachineTest, InitializationMemory) {
}

namespace {

template <class T>
inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
  AttrValue attr_value;

@ -463,6 +468,124 @@ TEST_F(SingleMachineTest, PersistentMemory) {
  EXPECT_TRUE(found_hashtable);
}

#if defined(PLATFORM_GOOGLE)
namespace {

SessionOptions GetSessionOption(int num_cpu_cores, int num_gpus) {
  SessionOptions options;
  // Copied from single_machine.h
  (*options.config.mutable_device_count())["CPU"] = 1;
  if (num_gpus > 0) {
    (*options.config.mutable_device_count())["GPU"] = num_gpus;
  }
  CHECK_GE(num_cpu_cores, 1);
  options.config.set_intra_op_parallelism_threads(num_cpu_cores);
  options.config.add_session_inter_op_thread_pool()->set_num_threads(
      num_cpu_cores);
  return options;
}

Status GetDeviceMemoryStats(
    const SessionOptions& session_option,
    std::unordered_map<string, AllocatorStats>* allocator_stats_by_device) {
  std::vector<Device*> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(session_option,
                                               "" /* name_prefix */, &devices));
  allocator_stats_by_device->clear();
  for (Device* device : devices) {
    AllocatorStats stats;
    auto* allocator = device->GetAllocator(AllocatorAttributes());
    if (!allocator->TracksAllocationSizes()) {
      return Status(error::INVALID_ARGUMENT,
                    "Tracking allocation is not enabled.");
    }
    allocator->GetStats(&stats);
    (*allocator_stats_by_device)[device->name()] = stats;
  }
  return Status::OK();
}

}  // namespace

TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Add a variable and initializer.
  Output a = ops::Variable(s.WithOpName("a"), TensorShape({128, 256}),
                           DataType::DT_FLOAT);
  Output a_init =
      ops::RandomNormal(s.WithOpName("a/init"), {128, 256}, DataType::DT_FLOAT);
  Output a_init_assign = ops::Assign(s.WithOpName("a/init/assign"), a, a_init);

  // Add a resource variable.
  Output b =
      ops::VarHandleOp(s.WithOpName("b"), DataType::DT_FLOAT, {256, 512});
  Output b_read =
      ops::ReadVariableOp(s.WithOpName("b/read"), b, DataType::DT_FLOAT);
  Output b_init =
      ops::RandomNormal(s.WithOpName("b/init"), {256, 512}, DataType::DT_FLOAT);
  auto b_init_assign =
      ops::AssignVariableOp(s.WithOpName("b/init/assign"), b, b_init);

  // Add a queue.
  ops::FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_STRING});
  Output some_string =
      ops::Const(s.WithOpName("some_string"), string("nothing"));
  ops::QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, {some_string});
  ops::QueueDequeue dequeue(s.WithOpName("dequeue"), queue,
                            {DataType::DT_STRING});

  // Add an IdentityReader.
  ops::IdentityReader reader(s.WithOpName("identity_reader"));
  ops::ReaderRead read(s.WithOpName("read_from_queue"), reader, queue);

  Output var_mul = ops::MatMul(s.WithOpName("var_matmul"), a, b_read);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  QueueRunnerDef queue_runner;
  queue_runner.set_queue_name("queue");
  *queue_runner.add_enqueue_op_name() = "enqueue";
  item.queue_runners.push_back(queue_runner);

  item.init_ops.push_back("a/init/assign");
  item.init_ops.push_back("b/init/assign");
  item.fetch.push_back("var_matmul");
  item.fetch.push_back("dequeue");

  // Run the graph
  TF_CHECK_OK(cluster_->Initialize(item));
  EnableCPUAllocatorStats(true);

  SessionOptions options =
      GetSessionOption(3 /* cpu cores */, 0 /* num gpus */);
  std::unordered_map<string, AllocatorStats> device_memory_before;
  TF_CHECK_OK(GetDeviceMemoryStats(options, &device_memory_before));
  EXPECT_EQ(device_memory_before.size(), 1);

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // Check that there is memory that has not been released yet.
  std::unordered_map<string, AllocatorStats> device_memory;
  TF_CHECK_OK(GetDeviceMemoryStats(options, &device_memory));
  EXPECT_EQ(device_memory.size(), 1);
  EXPECT_GT(device_memory.begin()->second.bytes_in_use, 0);

  // Resetting cluster_ should release all memory.
  cluster_.reset();
  std::unordered_map<string, AllocatorStats> device_memory_after;
  TF_CHECK_OK(GetDeviceMemoryStats(options, &device_memory_after));

  // Check that memory used by resources is released after cluster destruction.
  EXPECT_EQ(device_memory_before.size(), 1);
  EXPECT_EQ(device_memory_after.size(), 1);
  EXPECT_EQ(device_memory_before.begin()->second.bytes_in_use, 0);
  EXPECT_EQ(device_memory_after.begin()->second.bytes_in_use, 0);
}
#endif

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
@ -82,10 +82,6 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,

  // Inline all functions.
  GraphDef inlined_graph_def(graph_def);
  // Populate default attrs to the NodeDefs in the GraphDef, which is required
  // by inlining code.
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&inlined_graph_def, *OpRegistry::Global(), 0));

  for (int i = 0; i < inlined_graph_def.library().function().size(); i++) {
    FunctionDef* fdef =

@ -122,6 +118,10 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));
  // Populate default attrs to the NodeDefs in the GraphDef.
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
                                               *graphptr->op_registry(), 0));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_ctor_opts, inlined_graph_def,
                                            graphptr.get()));
@ -48,9 +48,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
  for (int i = 0; i < num_stages; i++) {
    std::vector<Output> this_stage;
    for (int j = 0; j < width; j++) {
      Output combine = AddN(
          s.WithDevice(device_names[use_multiple_devices ? j : 0]), last_stage);
      this_stage.push_back(combine);
      if (last_stage.size() == 1) {
        Output unary_op =
            Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
                   last_stage[0]);
        this_stage.push_back(unary_op);
      } else {
        Output combine =
            AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
                 last_stage);
        this_stage.push_back(combine);
      }
    }
    last_stage = this_stage;
  }
@ -18,6 +18,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

bool IsAddN(const NodeDef& node) {
  const auto op = node.op();
  return op == "AddN";
}

bool IsConcat(const NodeDef& node) {
  const auto op = node.op();
  return op == "Concat" || op == "ConcatV2";
@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

bool IsAddN(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
@ -26,6 +26,29 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

// Counts the inputs of a node, excluding control dependencies (inputs whose
// name starts with '^').
int NumNonControlInputs(const NodeDef& node) {
  int num_inputs = node.input_size();
  for (int i = 0; i < node.input_size(); ++i) {
    if (!node.input(i).empty() && node.input(i)[0] == '^') {
      num_inputs--;
    }
  }
  return num_inputs;
}

bool IsTrivialOp(const NodeDef& node) {
  // Remove the stop gradient nodes since they serve no purpose once the graph
  // is built. Also remove Identity ops.
  if (IsStopGradient(node) || IsIdentity(node)) {
    return true;
  }
  if (IsAddN(node) && NumNonControlInputs(node) <= 1) {
    return true;
  }

  return false;
}

Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
                             GraphDef* pruned_graph) {
  GraphRewriter rewriter(item);

@ -43,9 +66,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,

  std::unordered_set<const NodeDef*> nodes_to_delete;
  for (auto& node : item.graph.node()) {
    // Remove the stop gradient nodes since they serve no purpose once the graph
    // is built. Also remove Identity ops.
    if (!IsStopGradient(node) && !IsIdentity(node)) {
    if (!IsTrivialOp(node)) {
      continue;
    }
    // Don't remove nodes that must be preserved.

@ -57,10 +57,10 @@ TEST_F(ModelPrunerTest, StopGradientPruning) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output b = ops::Sqrt(s.WithOpName("b"), {a});
  Output c = ops::StopGradient(s.WithOpName("c"), b);
  Output d = ops::StopGradient(s.WithOpName("d"), c);
  Output e = ops::AddN(s.WithOpName("e"), {d});
  Output e = ops::Sqrt(s.WithOpName("e"), {d});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

@ -93,10 +93,10 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output b = ops::Sqrt(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), b);
  Output d = ops::Identity(s.WithOpName("d"), c);
  Output e = ops::AddN(s.WithOpName("e"), {d});
  Output e = ops::Sqrt(s.WithOpName("e"), {d});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

@ -126,15 +126,53 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
  EXPECT_EQ(NodeName(b.name()), new_c.input(0));
}

TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
TEST_F(ModelPrunerTest, NoOpPruning) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output c = ops::AddN(s.WithOpName("c"), {b});
  Output d = ops::AddN(s.WithOpName("d").WithControlDependencies(b), {c});
  Output e = ops::AddN(s.WithOpName("e"), {d});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ModelPruner pruner;
  GraphDef output;
  Status status = pruner.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(5, output.node_size());
  const NodeDef& new_a = output.node(0);
  EXPECT_EQ(NodeName(a.name()), new_a.name());
  const NodeDef& new_b = output.node(1);
  EXPECT_EQ(NodeName(b.name()), new_b.name());
  const NodeDef& new_c = output.node(2);
  EXPECT_EQ(NodeName(c.name()), new_c.name());
  const NodeDef& new_d = output.node(3);
  EXPECT_EQ(NodeName(d.name()), new_d.name());
  const NodeDef& new_e = output.node(4);
  EXPECT_EQ(NodeName(e.name()), new_e.name());

  EXPECT_EQ(1, new_e.input_size());
  EXPECT_EQ(NodeName(d.name()), new_e.input(0));
  EXPECT_EQ(2, new_d.input_size());
  EXPECT_EQ(NodeName(b.name()), new_d.input(0));
  EXPECT_EQ(1, new_c.input_size());
  EXPECT_EQ(NodeName(b.name()), new_c.input(0));
}

TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::Sqrt(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), b);
  Output d = ops::Identity(s.WithOpName("d"), c);
  Output e = ops::AddN(s.WithOpName("e").WithControlDependencies(c), {d});
  Output e = ops::Sqrt(s.WithOpName("e").WithControlDependencies(c), {d});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

@ -166,11 +204,11 @@ TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output c = ops::AddN(s.WithOpName("c"), {a});
  Output b = ops::Sqrt(s.WithOpName("b"), {a});
  Output c = ops::Sqrt(s.WithOpName("c"), {a});
  Output d = ops::Identity(s.WithOpName("d"), c);
  Output e = ops::Identity(s.WithOpName("e"), d);
  Output f = ops::AddN(s.WithOpName("f"), {e});
  Output f = ops::Sqrt(s.WithOpName("f"), {e});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

@ -216,7 +254,7 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output b = ops::Sqrt(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), b);

  GrapplerItem item;

@ -243,13 +281,13 @@ TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) {

  // Node i1 should be preserved.
  Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c);
  Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
  Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});
  Output a1 = ops::Sqrt(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
  Output a2 = ops::Sqrt(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});

  // Node i2 should be pruned since it resides on the sender's device.
  Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c);
  Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
  Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});
  Output a3 = ops::Sqrt(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
  Output a4 = ops::Sqrt(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

@ -237,6 +237,8 @@ class SymbolicGradientOp : public AsyncOpKernel {

  FunctionLibraryRuntime::Options opts;
  opts.step_id = ctx->step_id();
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();
  std::vector<Tensor> args;
  args.reserve(ctx->num_inputs());

@ -124,6 +124,20 @@ class MutableHashTableOfScalars final : public LookupInterface {

  TensorShape value_shape() const override { return TensorShape(); }

  int64 MemoryUsed() const override {
    // Rough approximation: one unit per empty bucket and one per stored
    // entry, plus the size of the table object itself.
    int64 ret = 0;
    mutex_lock l(mu_);
    for (unsigned i = 0; i < table_.bucket_count(); ++i) {
      size_t bucket_size = table_.bucket_size(i);
      if (bucket_size == 0) {
        ret++;
      } else {
        ret += bucket_size;
      }
    }
    return sizeof(MutableHashTableOfScalars) + ret;
  }

 private:
  // TODO(andreasst): consider using a read/write lock or a concurrent map
  mutable mutex mu_;

@ -239,6 +253,20 @@ class MutableHashTableOfTensors final : public LookupInterface {

  TensorShape value_shape() const override { return value_shape_; }

  int64 MemoryUsed() const override {
    int64 ret = 0;
    mutex_lock l(mu_);
    for (unsigned i = 0; i < table_.bucket_count(); ++i) {
      size_t bucket_size = table_.bucket_size(i);
      if (bucket_size == 0) {
        ret++;
      } else {
        ret += bucket_size;
      }
    }
    return sizeof(MutableHashTableOfTensors) + ret;
  }

 private:
  TensorShape value_shape_;
  // TODO(andreasst): consider using a read/write lock or a concurrent map

@ -467,6 +495,12 @@ class MutableDenseHashTable final : public LookupInterface {

  TensorShape value_shape() const override { return value_shape_; }

  int64 MemoryUsed() const override {
    mutex_lock l(mu_);
    return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
           value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
  }

 private:
  Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
                  bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {

@ -39,6 +39,19 @@ static void GetRendezvousKey(const string& key_prefix,
                   frame_iter.iter_id);
}

static FrameAndIter GetFrameAndIter(OpKernelContext* ctx,
                                    bool hostmem_sendrecv) {
  if (hostmem_sendrecv && ctx->call_frame() != nullptr) {
    // Host memory send/recv pairs are added by
    // common_runtime/memory_types.cc. When the pair of nodes is
    // added inside a function, we need to use the function call frame
    // to formulate the unique rendezvous key.
    return FrameAndIter(reinterpret_cast<uint64>(ctx->call_frame()), 0);
  } else {
    return ctx->frame_iter();
  }
}

SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  string send_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));

@ -56,6 +69,9 @@ SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  // proactively cache the rendezvous key for the top-level.
  GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
  if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
    hostmem_sendrecv_ = false;
  }
}

void SendOp::Compute(OpKernelContext* ctx) {

@ -71,7 +87,8 @@ void SendOp::Compute(OpKernelContext* ctx) {
  args.device_context = ctx->op_device_context();
  args.alloc_attrs = ctx->input_alloc_attr(0);

  if (ctx->frame_iter() == FrameAndIter(0, 0)) {
  FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
  if (frame_iter == FrameAndIter(0, 0)) {
    // Use the cached rendezvous key.
    VLOG(2) << "Send " << parsed_key_.buf_;
    OP_REQUIRES_OK(ctx,

@ -79,7 +96,7 @@ void SendOp::Compute(OpKernelContext* ctx) {
                                  ctx->is_input_dead()));
  } else {
    Rendezvous::ParsedKey in_loop_parsed;
    GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
    GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
    VLOG(2) << "Send " << in_loop_parsed.buf_;
    OP_REQUIRES_OK(ctx,
                   Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));

@ -120,6 +137,9 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  // proactively cache the rendezvous key for the top-level.
  GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
  if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
    hostmem_sendrecv_ = false;
  }
}

void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {

@ -151,12 +171,13 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
      },
      std::move(done), _1, _2, _3, _4, _5);

  if (ctx->frame_iter() == FrameAndIter(0, 0)) {
  FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
  if (frame_iter == FrameAndIter(0, 0)) {
    VLOG(2) << "Recv " << parsed_key_.buf_;
    ctx->rendezvous()->RecvAsync(parsed_key_, args, std::move(done_cb));
  } else {
    Rendezvous::ParsedKey in_loop_parsed;
    GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
    GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
    VLOG(2) << "Recv " << in_loop_parsed.buf_;
    OP_REQUIRES_OK_ASYNC(
        ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);

@ -29,6 +29,7 @@ class SendOp : public OpKernel {
 private:
  string key_prefix_;
  Rendezvous::ParsedKey parsed_key_;
  bool hostmem_sendrecv_;

  TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
};

@ -41,6 +42,7 @@ class RecvOp : public AsyncOpKernel {
 private:
  string key_prefix_;
  Rendezvous::ParsedKey parsed_key_;
  bool hostmem_sendrecv_;

  TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
};

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
class DeviceMgr;

/// \brief A Session instance lets a caller drive a TensorFlow graph
/// computation.

@ -177,12 +178,24 @@ class Session {
  /// *response. This API is optional. If it is unimplemented, Status will
  /// return a corresponding error message, and *response will be unmodified.
  virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0;

  /// \brief Closes this session.
  ///
  /// Closing a session releases the resources used by this session
  /// on the TensorFlow runtime (specified during session creation by
  /// the `SessionOptions::target` field).
  virtual Status Close() = 0;

  // NOTE(ashankar): As of July 2017, this was a method added to
  // facilitate some experimentation. Reconsider/re-evaluate after
  // September 2017.
  //
  // Sets `*output` to the `DeviceMgr` that owns accessible devices in the
  // address-space of the caller.
  virtual Status LocalDeviceManager(const DeviceMgr** output) {
    return errors::Unimplemented(
        "LocalDeviceManager is not supported for this session.");
  }
};

/// \brief Create a new session with the given options.

tensorflow/docs_src/programmers_guide/estimators.md (new file, 153 lines)
@ -0,0 +1,153 @@
# Estimators

This document introduces **Estimators**--a high-level TensorFlow API that
greatly simplifies machine learning programming. Estimators encapsulate
the following actions:

* training
* evaluation
* prediction
* export for serving

You may either use the pre-made Estimators we provide or write your
own custom Estimators. All Estimators--whether pre-made or custom--are
classes based on the `tf.estimator.Estimator` class.

Note: TensorFlow also provides an Estimator class at
`tf.contrib.learn.Estimator`, which you should not use.


## Advantages of Estimators

Estimators provide the following benefits:

* You can run Estimators-based models on a local host or on a
  distributed multi-server environment without changing your model.
  Furthermore, you can run Estimators-based models on CPUs, GPUs,
  or TPUs without recoding your model.
* Estimators simplify sharing implementations between model developers.
* You can develop a state-of-the-art model with high-level, intuitive code.
  In short, it is generally much easier to create models with Estimators
  than with the low-level TensorFlow APIs.
* Estimators are themselves built on `tf.layers`, which
  simplifies customization.
* Estimators build the graph for you. In other words, you don't have to
  build the graph.
* Estimators provide a safe distributed training loop that controls how and
  when to:
    * build the graph
    * initialize variables
    * start queues
    * handle exceptions
    * create checkpoint files and recover from failures
    * save summaries for TensorBoard

When writing an application with Estimators, you must separate the data input
pipeline from the model. This separation simplifies experiments with
different data sets.


## Pre-made Estimators

Pre-made Estimators enable you to work at a much higher conceptual level
than the base TensorFlow APIs. You no longer have to worry about creating
the computational graph or sessions since Estimators handle all
the "plumbing" for you. That is, pre-made Estimators create and manage
`Graph` and `Session` objects for you. Furthermore, pre-made Estimators
let you experiment with different model architectures by making only minimal
code changes. `DNNClassifier`, for example, is a pre-made Estimator class that
trains classification models through dense, feed-forward neural networks.


### Structure of a pre-made Estimators program

A TensorFlow program relying on a pre-made Estimator typically consists
of the following four steps:

1.  **Write one or more dataset importing functions.** For example, you might
    create one function to import the training set and another function to
    import the test set. Each dataset importing function must return two
    objects:

    *   a dictionary in which the keys are feature column names and the
        values are Tensors (or SparseTensors) containing the corresponding
        feature data
    *   a Tensor containing one or more labels

    For example, the following code illustrates the basic skeleton for
    an input function:

        def input_fn(dataset):
           ...  # manipulate dataset, extracting feature names and the label
           return feature_dict, label

    See @{$datasets$Using the `Dataset` API for TensorFlow Input Pipelines}
    for full details.

2.  **Define the feature columns.** Each @{tf.feature_column}
    identifies a feature name, its type, and any input pre-processing.
    For example, the following snippet creates three feature
    columns that hold integer or floating-point data. The first two
    feature columns simply identify the feature's name and type. The
    third feature column also specifies a lambda the program will invoke
    to scale the raw data:

        # Define three numeric feature columns.
        population = tf.feature_column.numeric_column('population')
        crime_rate = tf.feature_column.numeric_column('crime_rate')
        median_education = tf.feature_column.numeric_column(
            'median_education',
            normalizer_fn=lambda x: x - global_education_mean)

3.  **Instantiate the relevant pre-made Estimator.** For example, here's
    a sample instantiation of a pre-made Estimator named `LinearClassifier`:

        # Instantiate an estimator, passing the feature columns.
        estimator = tf.estimator.LinearClassifier(
            feature_columns=[population, crime_rate, median_education])

4.  **Call a training, evaluation, or inference method.**
    For example, all Estimators provide a `train` method, which trains a
    model. A minimal sketch combining all four steps appears after this list.

        # my_training_set is the function created in Step 1
        estimator.train(input_fn=my_training_set, steps=2000)

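Putting the four steps together, here is a minimal end-to-end sketch. It is
illustrative only: `train_input_fn` and `eval_input_fn` are assumed stand-ins
for the dataset importing functions written in Step 1, and the feature columns
are a subset of the ones defined in Step 2:

    # A minimal sketch, assuming train_input_fn and eval_input_fn are the
    # dataset importing functions from Step 1 (hypothetical names).
    population = tf.feature_column.numeric_column('population')
    crime_rate = tf.feature_column.numeric_column('crime_rate')

    estimator = tf.estimator.LinearClassifier(
        feature_columns=[population, crime_rate])

    estimator.train(input_fn=train_input_fn, steps=2000)      # train
    metrics = estimator.evaluate(input_fn=eval_input_fn)      # evaluate
    predictions = estimator.predict(input_fn=eval_input_fn)   # predict

Note that switching to a different pre-made Estimator (for example,
`DNNClassifier`) only requires changing the instantiation in the middle of
this sketch.
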
### Benefits of pre-made Estimators

Pre-made Estimators encode best practices, providing the following benefits:

* Best practices for determining where different parts of the computational
  graph should run, implementing strategies on a single machine or on a
  cluster.
* Best practices for event (summary) writing and universally useful
  summaries.

If you don't use pre-made Estimators, you must implement the preceding
features yourself.


## Custom Estimators

The heart of every Estimator--whether pre-made or custom--is its
**model function**, which is a method that builds graphs for training,
evaluation, and prediction. When you are using a pre-made Estimator,
someone else has already implemented the model function. When relying
on a custom Estimator, you must write the model function yourself. A
@{$extend/estimators$companion document}
explains how to write the model function.

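To make the model-function idea concrete, here is a minimal sketch of what a
custom model function might look like. The layer sizes, the optimizer, and
`my_feature_columns` are illustrative assumptions, not requirements taken from
this document:

    # A minimal model_fn sketch; my_feature_columns is an assumed,
    # pre-built list of feature columns.
    def my_model_fn(features, labels, mode):
      # Build the graph: one hidden layer, one logit per class.
      net = tf.feature_column.input_layer(features, my_feature_columns)
      net = tf.layers.dense(net, units=10, activation=tf.nn.relu)
      logits = tf.layers.dense(net, units=3)

      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'logits': logits})

      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                    logits=logits)
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdagradOptimizer(0.1).minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

      # EVAL mode: just report the loss.
      return tf.estimator.EstimatorSpec(mode, loss=loss)

    estimator = tf.estimator.Estimator(model_fn=my_model_fn)
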
## Recommended workflow

We recommend the following workflow:

1. Assuming a suitable pre-made Estimator exists, use it to build your
   first model and use its results to establish a baseline.
2. Build and test your overall pipeline, including the integrity and
   reliability of your data, with this pre-made Estimator.
3. If suitable alternative pre-made Estimators are available, run
   experiments to determine which pre-made Estimator produces the
   best results.
4. Possibly, further improve your model by building your own custom Estimator.

@ -325,6 +325,25 @@ class FunctionTest(test.TestCase):
                                 "assertion"):
        _ = MyFn(100.0).eval()

  def testWhileLoopCallsFunc(self):
    with self.test_session(use_gpu=True) as sess:

      @function.Defun(dtypes.float32)
      def Times2(x):
        constant_two = constant_op.constant(2, dtypes.int32)
        two_on_gpu = math_ops.cast(constant_two, dtypes.float32)
        return x * two_on_gpu

      def Body(x):
        x2 = Times2(x)
        x2.set_shape([])
        return x2

      loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])

      ans = sess.run(loop)
      self.assertAllClose(ans, 131072.)

  def testControlFlowStrictness(self):
    """Inlined functions must not execute in an untaken control flow branch."""

@ -588,8 +607,8 @@ class FunctionTest(test.TestCase):
    self.assertAllClose(vals[2], vals[3])

  def testDeclare(self):
    foo = function.Declare("Foo", [("x", dtypes.float32)],
                           [("y", dtypes.float32)])
    foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
                                                             dtypes.float32)])

    @function.Defun(dtypes.float32, func_name="Foo", out_names=["y"])
    def FooImpl(x):

@ -607,8 +626,8 @@ class FunctionTest(test.TestCase):
    self.assertAllClose(expected, y.eval(feed_dict={x: rand}))

  def testDeclareUsedInDefun(self):
    foo = function.Declare("Foo", [("x", dtypes.float32)],
                           [("y", dtypes.float32)])
    foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
                                                             dtypes.float32)])

    @function.Defun()
    def Bar(x):

@ -630,8 +649,8 @@ class FunctionTest(test.TestCase):
    self.assertAllClose(expected, y.eval(feed_dict={x: rand}))

  def testDeclareTypeMistake(self):
    foo = function.Declare("Foo", [("x", dtypes.float32)],
                           [("y", dtypes.float32)])
    foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
                                                             dtypes.float32)])

    @function.Defun(dtypes.float32, func_name="Foo", out_names=["y"])
    def Foo(x):

@ -749,8 +768,9 @@ class FunctionTest(test.TestCase):
    self.assertAllEqual(v1, 20.)

  def testShapeFunction(self):
    @function.Defun(dtypes.float32,
                    shape_func=lambda op: [op.inputs[0].get_shape()])

    @function.Defun(
        dtypes.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
    def Foo(x):
      return x + 1.0

@ -767,11 +787,12 @@ class FunctionTest(test.TestCase):
    self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])

  def testVariableReuse(self):

    def LinearWithReuse(input_tensor, reuse=None):
      size = input_tensor.shape.dims[1]
      with variable_scope.variable_scope("linear", reuse=reuse):
        w = variable_scope.get_variable("w", shape=[size, size],
                                        dtype=input_tensor.dtype)
        w = variable_scope.get_variable(
            "w", shape=[size, size], dtype=input_tensor.dtype)
      return math_ops.matmul(input_tensor, w)

    @function.Defun(dtypes.float32)

@ -789,15 +810,19 @@ class FunctionTest(test.TestCase):

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      output_val = sess.run(output_op,
                            feed_dict={input_op: np.random.rand(32, 100)})
      output_val = sess.run(
          output_op, feed_dict={input_op: np.random.rand(32, 100)})
      self.assertEqual(output_val.shape, (32, 100))

  def testFunctionCallInDifferentVariableScopes(self):

    @function.Defun(dtypes.float32)
    def Foo(inputs):
      var = variable_scope.get_variable("var", shape=[10], dtype=dtypes.float32,
                                        initializer=init_ops.ones_initializer())
      var = variable_scope.get_variable(
          "var",
          shape=[10],
          dtype=dtypes.float32,
          initializer=init_ops.ones_initializer())
      return inputs + var

    input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32)

@ -813,8 +838,8 @@ class FunctionTest(test.TestCase):

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      out1, out2 = sess.run([out1_op, out2_op],
                            feed_dict={input_op: np.linspace(1, 10, 10)})
      out1, out2 = sess.run(
          [out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
      self.assertAllEqual(out1, np.linspace(2, 11, 10))
      self.assertAllEqual(out2, np.linspace(2, 11, 10))

@ -852,12 +877,15 @@ class FunctionsFromProtos(test.TestCase):
    self.assertEqual(func.captured_inputs, new_func.captured_inputs)

  def testBasic(self):

    @function.Defun(dtypes.float32, dtypes.float32)
    def Foo(x, y):
      return x + y

    self.expectFunctionsEqual(Foo)

  def testGradFunc(self):

    @function.Defun(dtypes.float32, dtypes.float32)
    def G(x, dy):
      return x * dy

@ -865,10 +893,12 @@ class FunctionsFromProtos(test.TestCase):
    @function.Defun(dtypes.float32, grad_func=G)
    def F(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    self.expectFunctionsEqual(F, grad_func=G)

  def testCapturedInputs(self):
    c = constant_op.constant(10, dtypes.int64)

    @function.Defun(dtypes.int64)
    def Foo(x):
      return x + c

@ -885,6 +915,7 @@ class FunctionsFromProtos(test.TestCase):
    self.assertEqual(len(new_func.captured_inputs), 0)

  def testNestedFunctions(self):

    @function.Defun(dtypes.float32)
    def Outer(x):

@ -958,6 +989,7 @@ class FunctionsFromProtos(test.TestCase):
    self.assertEqual(len(function._from_library(library)), 0)

  def testFromLibraryMissingFuncDef(self):

    @function.Defun(dtypes.float32, dtypes.float32)
    def G1(x, dy):
      return x * dy

@ -989,6 +1021,7 @@ class FunctionsFromProtos(test.TestCase):
      function._from_library(library)

  def testFromLibraryCyclicGradFuncs(self):

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

@ -1242,10 +1275,11 @@ class FunctionInlineControlTest(test.TestCase):
    inp = np.random.uniform(-1, 1, [16, 1]).astype(np.float32)
    run_metadata = config_pb2.RunMetadata()
    with session.Session(graph=g, config=cfg) as sess:
      ans = sess.run([y, dx], {x: inp},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      ans = sess.run(
          [y, dx], {x: inp},
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      print(ans[0], np.sum(ans[1]))
      self.assertAllClose(ans[0], 255.971, rtol=1e-3)
      self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)

@ -1275,8 +1309,7 @@ class ModuleFunctionTest(test.TestCase):
  def testBasic(self):
    with ops.Graph().as_default():
      a, b, c, d, e = [
          constant_op.constant(
              [[_]], dtype=dtypes.float32) for _ in range(5)
          constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
      ]
      y = Linear(a, b, c)
      z = Linear2(a, b, c, d, e)

@ -1295,7 +1328,8 @@ class VariableHoistingTest(test.TestCase):
        initializer=init_ops.random_uniform_initializer(seed=312),
        use_resource=use_resource)
    b = variable_scope.get_variable(
        "b", (64), initializer=init_ops.zeros_initializer(),
        "b", (64),
        initializer=init_ops.zeros_initializer(),
        use_resource=use_resource),
    return math_ops.sigmoid(math_ops.matmul(x, w) + b)

@ -1354,5 +1388,6 @@ class VariableHoistingTest(test.TestCase):
    self._testSimpleModel(True, use_resource=True)
    self._testSimpleModel(False, use_resource=True)


if __name__ == "__main__":
  test.main()

@ -148,7 +148,7 @@ def input_producer(input_tensor,
  """
  with ops.name_scope(name, "input_producer", [input_tensor]):
    input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
    element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
    element_shape = input_tensor.shape[1:].merge_with(element_shape)
    if not element_shape.is_fully_defined():
      raise ValueError("Either `input_tensor` must have a fully defined shape "
                       "or `element_shape` must be specified")

@ -168,7 +168,7 @@ def input_producer(input_tensor,
        q, [enq], cancel_op=cancel_op))
    if summary_name is not None:
      summary.scalar(summary_name,
                     math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
                     math_ops.to_float(q.size()) * (1. / capacity))
    return q

@ -465,7 +465,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
  def _sparse_meta_data(t, storing_op, map_op):
    if not isinstance(t, sparse_tensor.SparseTensor):
      return _SparseMetaData(False, None, None)
    rank = t.dense_shape.get_shape().with_rank(1)[0]
    rank = t.dense_shape.shape.with_rank(1)[0]
    if enqueue_many:
      rank -= 1
    # If a shared map_op was provided, use that. Otherwise use the name of

@ -492,7 +492,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
          lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
      out_tensor.set_shape([None])  # necessary when t.ndims is unknown
      return out_tensor
    if keep_input.get_shape().ndims == 1:
    if keep_input.shape.ndims == 1:
      t = sparse_ops.sparse_retain(t, keep_input)
      store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name)
    elif enqueue_many:

@ -577,13 +577,13 @@ def _validate_join(tensor_list_list):
def _validate_keep_input(keep_input, enqueue_many):
  """Validate `keep_input` argument to conditional batching functions."""
  keep_input = ops.convert_to_tensor(keep_input)
  if keep_input.get_shape().ndims is None:
  if keep_input.shape.ndims is None:
    raise ValueError(
        "`keep_input` dimensions must be known at graph construction.")
  if not enqueue_many and keep_input.get_shape().ndims == 1:
  if not enqueue_many and keep_input.shape.ndims == 1:
    raise ValueError(
        "`keep_input` cannot be a vector when `enqueue_many=False`.")
  if keep_input.get_shape().ndims > 1:
  if keep_input.shape.ndims > 1:
    raise ValueError("`keep_input` must be 0 or 1 dimensions.")
  return keep_input

@ -632,18 +632,18 @@ def _shapes(tensor_list_list, shapes, enqueue_many):

  for tl in tensor_list_list:
    for i in xrange(len0):
      if tl[i].get_shape().ndims is None:
      if tl[i].shape.ndims is None:
        raise ValueError("Cannot infer Tensor's rank: %s" % tl[i])

  shapes = [_merge_shapes(
      [tl[i].get_shape().as_list() for tl in tensor_list_list], enqueue_many)
      [tl[i].shape.as_list() for tl in tensor_list_list], enqueue_many)
            for i in xrange(len0)]
  return shapes


def _select_which_to_enqueue(tensor_list, keep_input):
  """Select which examples to enqueue based on vector `keep_input`."""
  select_i = math_ops.cast(keep_input, dtypes.int32)
  select_i = math_ops.to_int32(keep_input)
  tensor_list = [
      data_flow_ops.dynamic_partition(x, select_i, num_partitions=2)[1]
      for x in tensor_list]

@ -656,7 +656,7 @@ def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input):
    enqueue_fn = queue.enqueue_many
  else:
    enqueue_fn = queue.enqueue
  if keep_input.get_shape().ndims == 1:
  if keep_input.shape.ndims == 1:
    enqueue_ops = [enqueue_fn(_select_which_to_enqueue(x, keep_input))
                   for x in tensor_list_list]
  else:

@ -673,7 +673,7 @@ def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input):
    enqueue_fn = queue.enqueue_many
  else:
    enqueue_fn = queue.enqueue
  if keep_input.get_shape().ndims == 1:
  if keep_input.shape.ndims == 1:
    enqueue_ops = [
        enqueue_fn(_select_which_to_enqueue(tensor_list, keep_input))] * threads
  else:

@ -707,8 +707,7 @@ def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32,
      capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
  _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input)
  summary.scalar("fraction_of_%d_full" % capacity,
                 math_ops.cast(queue.size(), dtypes.float32) *
                 (1. / capacity))
                 math_ops.to_float(queue.size()) * (1. / capacity))

  if allow_smaller_final_batch:
    dequeued = queue.dequeue_up_to(batch_size, name=name)

@ -742,8 +741,7 @@ def _batch_join(tensors_list, batch_size, keep_input, capacity=32,
      capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
  _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input)
  summary.scalar("fraction_of_%d_full" % capacity,
                 math_ops.cast(queue.size(), dtypes.float32) *
                 (1. / capacity))
                 math_ops.to_float(queue.size()) * (1. / capacity))

  if allow_smaller_final_batch:
    dequeued = queue.dequeue_up_to(batch_size, name=name)

@ -775,8 +773,8 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
      capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
      dtypes=types, shapes=shapes, shared_name=shared_name)
  _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input)
  full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
                        dtypes.float32) *
  full = (math_ops.to_float(
      math_ops.maximum(0, queue.size() - min_after_dequeue)) *
          (1. / (capacity - min_after_dequeue)))
  # Note that name contains a '/' at the end so we intentionally do not place
  # a '/' after %s below.

@ -812,8 +810,8 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
      capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
      dtypes=types, shapes=shapes, shared_name=shared_name)
  _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input)
  full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
                        dtypes.float32) *
  full = (math_ops.to_float(
      math_ops.maximum(0, queue.size() - min_after_dequeue)) *
          (1. / (capacity - min_after_dequeue)))
  # Note that name contains a '/' at the end so we intentionally do not place
  # a '/' after %s below.