Merge pull request #11789 from vrv/branch_163213141

Branch 163213141
Vijay Vasudevan 2017-07-26 13:09:00 -07:00 committed by GitHub
commit ae22de6596
64 changed files with 1647 additions and 322 deletions

View File

@ -149,4 +149,12 @@ ClientLibrary::GetOrCreateCompileOnlyClient(
return cl;
}
/* static */ void ClientLibrary::DestroyLocalInstances() {
ClientLibrary& client_library = Singleton();
tensorflow::mutex_lock lock(client_library.service_mutex_);
client_library.local_instances_.clear();
client_library.compile_only_instances_.clear();
}
} // namespace xla

View File

@ -93,6 +93,11 @@ class ClientLibrary {
static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
perftools::gputools::Platform* platform = nullptr);
// Clears the local instance and compile only instance caches. The client
// pointers returned by the previous GetOrCreateLocalClient() or
// GetOrCreateCompileOnlyClient() invocations are not valid anymore.
static void DestroyLocalInstances();
private:
// Returns the singleton instance of ClientLibrary.
static ClientLibrary& Singleton();
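
A minimal usage sketch (not part of the diff) of the new API: pointers handed out by GetOrCreateLocalClient() are owned by the ClientLibrary singleton, so anything cached before DestroyLocalInstances() must be re-fetched afterwards. The include path and ValueOrDie() follow the XLA client conventions of this era; the function name ResetXlaClients is illustrative only.

  #include "tensorflow/compiler/xla/client/client_library.h"

  void ResetXlaClients() {
    xla::LocalClient* client =
        xla::ClientLibrary::GetOrCreateLocalClient().ValueOrDie();
    // ... run computations with 'client' ...
    xla::ClientLibrary::DestroyLocalInstances();
    // 'client' now dangles; request a fresh instance before further use.
    client = xla::ClientLibrary::GetOrCreateLocalClient().ValueOrDie();
  }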

View File

@ -1208,6 +1208,7 @@ cc_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],

View File

@ -1586,6 +1586,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
// module, invalidating iteration.
std::vector<HloComputation*> computations;
for (auto& comp : module->computations()) {
if (comp->IsFusionComputation()) {
continue;
}
computations.push_back(comp.get());
}
for (auto& comp : computations) {
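
This is the iteration idiom the commit applies across most passes below (BatchNormRewriter, BufferLiveness, CopyInsertion, HloCSE, HloDCE, FusionMerger, ReshapeMover, TransposeFolding, and others): fusion computations now live on the HloModule, so module->computations() returns them too, and passes that may add computations must iterate over a snapshot. A condensed sketch of the pattern; RunOnComputation is a placeholder name, not a real pass method:

  // Snapshot first: the pass can add computations to the module, which
  // would invalidate iteration over module->computations().
  std::vector<HloComputation*> computations;
  for (auto& comp : module->computations()) {
    // Fusion computations are reached via their enclosing kFusion
    // instruction; never process them as top-level computations.
    if (comp->IsFusionComputation()) {
      continue;
    }
    computations.push_back(comp.get());
  }
  for (HloComputation* comp : computations) {
    RunOnComputation(comp);  // placeholder for the actual pass body
  }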

View File

@ -268,6 +268,9 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
// module, invalidating iteration.
std::vector<HloComputation*> computations;
for (auto& comp : module->computations()) {
if (comp->IsFusionComputation()) {
continue;
}
computations.push_back(comp.get());
}
for (auto& comp : computations) {

View File

@ -1219,6 +1219,9 @@ void BufferAssigner::BuildColocatedBufferSets(
const TuplePointsToAnalysis& points_to_analysis =
buffer_liveness.points_to_analysis();
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
if (computation->IsFusionComputation()) {
continue;
}
for (const HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
const HloOpcode opcode = instruction->opcode();
@ -1386,6 +1389,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// their own BufferAllocation.
for (auto* computation : thread_local_computations) {
TF_RET_CHECK(computation != module->entry_computation());
if (computation->IsFusionComputation()) {
continue;
}
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
computation, module->config().debug_options(),
/*is_thread_local=*/true, colocated_buffers, colocated_allocations,

View File

@ -47,6 +47,9 @@ StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
tensorflow::Status BufferLiveness::Analyze() {
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
for (auto& computation : module_->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
// Gather all instructions whose buffers might alias other instructions into
// the set aliased_buffers_. This includes those contained as a tuple
// element in another instruction's output.

View File

@ -551,6 +551,9 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Add copies of computation root instructions, if needed.
FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
VLOG(2) << "computation " << computation->name();
InstructionCopier root_copier(computation->root_instruction(),
/*copy_users=*/{});

View File

@ -519,6 +519,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
new std::map<HloInstruction*, string>());
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
if (embedded_computation->IsFusionComputation()) {
continue;
}
auto parallel_computation_iter =
parallel_computations.find(embedded_computation);
// All parallel computations are considered to be an entry computation for
@ -591,6 +594,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
if (embedded_computation->IsFusionComputation()) {
continue;
}
TF_RETURN_IF_ERROR(
ir_emitter
.EmitComputation(embedded_computation,
@ -755,6 +761,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
if (embedded_computation->IsFusionComputation()) {
continue;
}
TF_RETURN_IF_ERROR(
ir_emitter
.EmitComputation(embedded_computation,

View File

@ -125,6 +125,9 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
HloInstruction* root = computation->root_instruction();
// Copy root instruction if it does not define its own top-level buffer.
// TODO(b/32885001) Remove these copies (at least for the unambiguous case).

View File

@ -293,12 +293,19 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
StatusOr<bool> FusionMerger::Run(HloModule* module) {
bool changed = false;
VLOG(2) << "FusionMerger for module: " << module->name();
std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
computations.push_back(computation.get());
}
for (auto& computation : computations) {
VLOG(1) << "Before running FusionInstructionMerger for computation: "
<< computation->name();
XLA_VLOG_LINES(3, computation->ToString());
FusionInstructionMerger fusion_merger(computation.get());
FusionInstructionMerger fusion_merger(computation);
TF_RETURN_IF_ERROR(fusion_merger.Run());
changed |= fusion_merger.changed();

View File

@ -120,7 +120,8 @@ GpuHloOrdering::GpuHloOrdering(
// do that yet since it's hard to ensure that the order here is the order used
// by IrEmitterNested. And mismatched ordering bugs would be hard to find.
for (auto& computation : module->computations()) {
if (computation.get() != module->entry_computation()) {
if (computation.get() != module->entry_computation() &&
!computation->IsFusionComputation()) {
predecessors_.emplace(computation.get(),
computation->ComputeReachability());
}

View File

@ -42,6 +42,9 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
bool changed = false;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
for (auto instruction : computation->MakeInstructionPostOrder()) {
// Skip dead code.
if (instruction->user_count() == 0 &&

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -329,7 +330,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}
using FusionCostAnalysis = ::testing::Test;
using FusionCostAnalysis = HloTestBase;
TEST_F(FusionCostAnalysis, LoopFusion) {
// Do this 4 times with different per-second rates to test the computation of
@ -345,32 +346,32 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
// mul = Mul(exp, C3)
// sub = Sub(mul, clamp)
// tuple = Tuple({sub, sub, mul, C1})
auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2));
auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2));
auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2));
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
auto c2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
auto c3 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
auto clamp = builder.AddInstruction(
HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
auto mul = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
auto sub = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});
auto add = HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(),
c2.get());
auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp,
c2.get(), add.get(), add.get());
auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get());
auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply,
exp.get(), c3.get());
auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract,
mul.get(), clamp.get());
auto tuple = HloInstruction::CreateTuple(
{sub.get(), sub.get(), mul.get(), c1.get()});
auto fusion = HloInstruction::CreateFusion(
r2f32, HloInstruction::FusionKind::kLoop, tuple.get());
fusion->FuseInstruction(sub.get());
fusion->FuseInstruction(mul.get());
fusion->FuseInstruction(exp.get());
fusion->FuseInstruction(clamp.get());
fusion->FuseInstruction(add.get());
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
// The time given these rates at i == 0 is exactly even among the properties
// at 1.0 seconds. For other values, one of the rates is slower so that it
@ -398,18 +399,21 @@ TEST_F(FusionCostAnalysis, NoLayout) {
Shape shape_without_layout = shape_with_layout;
shape_without_layout.clear_layout();
auto c1 = HloInstruction::CreateConstant(
Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
auto c2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));
auto broadcast =
HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
c1.get(), broadcast.get());
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
shape_with_layout, HloOpcode::kAdd, c1, broadcast));
auto fusion = HloInstruction::CreateFusion(
shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
fusion->FuseInstruction(broadcast.get());
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{add, broadcast}, HloInstruction::FusionKind::kLoop);
HloCostAnalysis fusion_analysis(ShapeSize);
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

View File

@ -92,6 +92,9 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
StatusOr<bool> HloCSE::Run(HloModule* module) {
bool changed = false;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
changed |= CombineConstants(computation.get(), is_layout_sensitive_);
std::list<HloInstruction*> post_order =

View File

@ -38,6 +38,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
bool changed = false;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
std::unordered_set<HloInstruction*> live_instructions;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
[&live_instructions](HloInstruction* instruction) {

View File

@ -560,19 +560,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
HloInstruction* instruction_to_fuse) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(instruction_to_fuse->IsFusable());
if (GetModule()) {
XLA_VLOG_LINES(1, GetModule()->ToString());
}
HloInstruction* clone = nullptr;
if (fused_instructions_computation_ == nullptr) {
if (called_computations_.empty()) {
// New fusion instruction.
auto builder = HloComputation::Builder("fused_computation", true);
builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
fused_instructions_computation_ = builder.Build();
called_computations_.push_back(
CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
clone->parent_fusion_instruction_ = this;
} else {
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
clone = fused_instructions_computation_->AddInstruction(
clone = fused_instructions_computation()->AddInstruction(
instruction_to_fuse->Clone(/*suffix=*/""));
clone->parent_fusion_instruction_ = this;
// instruction_to_fuse is necessarily an operand of the fusion instruction.
@ -583,7 +584,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) !=
operands_.end());
const std::vector<HloInstruction*>& fused_parameters_ =
fused_instructions_computation_->parameter_instructions();
fused_instructions_computation()->parameter_instructions();
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
if (instruction_to_fuse == operands_[operand_num]) {
// replace the fused parameter instruction's uses with the clone.
@ -593,7 +594,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// Remove the corresponding fused parameter and operand from their
// respective vectors.
TF_CHECK_OK(
fused_instructions_computation_->RemoveParameter(operand_num));
fused_instructions_computation()->RemoveParameter(operand_num));
operands_.erase(operands_.begin() + operand_num);
break;
}
@ -605,7 +606,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// Reread the parameters in the computation.
const std::vector<HloInstruction*>& fused_parameters_ =
fused_instructions_computation_->parameter_instructions();
fused_instructions_computation()->parameter_instructions();
// Add each operand of the clone as an operand of the fusion instruction. A
// complication is that some clone operands may already be operands of the
@ -638,7 +639,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
CreateParameter(param_no, operand->shape(), param_name);
param_instruction->parent_fusion_instruction_ = this;
fused_param = fused_instructions_computation_->AddParameter(
fused_param = fused_instructions_computation()->AddParameter(
std::move(param_instruction));
AppendOperand(operand);
}
@ -652,7 +653,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
called_computations_.push_back(computation);
}
}
return clone;
}
@ -663,17 +663,15 @@ RandomDistribution HloInstruction::random_distribution() const {
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
fused_instructions_computation_->instructions();
fused_instructions_computation()->instructions();
// All instructions owned by this fusion instruction must be fused, and the
// parent fusion instruction of the fused instructions must be 'this'.
for (auto& instruction : fused_instructions_) {
CHECK(instruction->IsFused());
CHECK_EQ(this, instruction->fusion_instruction());
CHECK_EQ(fused_instructions_computation_.get(), instruction->parent())
CHECK_EQ(fused_instructions_computation(), instruction->parent())
<< instruction->ToString();
}
@ -976,8 +974,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(parent() != nullptr);
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
auto new_instruction =
WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
@ -992,9 +988,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
// fused instructions.
std::vector<HloInstruction*> new_fused_parameters;
const std::vector<HloInstruction*>& fused_parameters_ =
fused_instructions_computation_->parameter_instructions();
fused_instructions_computation()->parameter_instructions();
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
fused_instructions_computation_->instructions();
fused_instructions_computation()->instructions();
for (HloInstruction* old_fused_parameter : fused_parameters_) {
new_fused_instructions.push_back(old_fused_parameter->Clone());
@ -1028,7 +1024,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
}
new_instruction->fusion_kind_ = fusion_kind_;
auto computation_builder = HloComputation::Builder(
fused_instructions_computation_->name() + ".clone", true);
fused_instructions_computation()->name() + ".clone", true);
// We iterated the fusion instructions in reverse post order which means
// that we must reverse our new list of fusion instructions.
for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
@ -1037,8 +1033,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
}
auto fused_root_ = fused_expression_root();
new_instruction->fused_instructions_computation_ =
computation_builder.Build(FindOrDie(old_to_new, fused_root_));
new_instruction->called_computations_.push_back(
CHECK_NOTNULL(GetModule())
->AddEmbeddedComputation(
computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
new_instruction->set_parent(parent());
new_instruction->CheckFusionInstruction();
return new_instruction;
@ -1769,7 +1767,10 @@ bool HloInstruction::IsFusable() const {
HloComputation* HloInstruction::fused_instructions_computation() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
return fused_instructions_computation_.get();
CHECK(!called_computations_.empty());
auto* fused_instructions_computation = called_computations_.front();
CHECK(fused_instructions_computation->IsFusionComputation());
return fused_instructions_computation;
}
HloInstruction* HloInstruction::fusion_instruction() const {
@ -1779,32 +1780,24 @@ HloInstruction* HloInstruction::fusion_instruction() const {
HloInstruction* HloInstruction::fused_expression_root() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
return fused_instructions_computation_->root_instruction();
return fused_instructions_computation()->root_instruction();
}
HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
return fused_instructions_computation_->parameter_instruction(
return fused_instructions_computation()->parameter_instruction(
parameter_number);
}
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
return fused_instructions_computation_->parameter_instructions();
return fused_instructions_computation()->parameter_instructions();
}
const std::list<std::unique_ptr<HloInstruction>>&
HloInstruction::fused_instructions() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(fused_instructions_computation_ != nullptr &&
fused_instructions_computation_->IsFusionComputation());
return fused_instructions_computation_->instructions();
return fused_instructions_computation()->instructions();
}
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
@ -2039,7 +2032,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit,
bool ignore_control_predecessors) {
VLOG(2) << "HloInstruction::Accept(" << name() << ")";
VLOG(3) << "HloInstruction::Accept(" << name() << ")";
TF_RETURN_IF_ERROR(
PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
if (call_finish_visit) {
@ -2055,8 +2048,11 @@ Status HloInstruction::AcceptWithOperandOrder(
TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order,
/*ignore_control_predecessors=*/false));
if (call_finish_visit) {
VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
}
VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
return Status::OK();
}
@ -2458,6 +2454,7 @@ HloModule* HloInstruction::GetModule() const {
}
void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
string parent_str = parent() == nullptr ? "noparent" : parent()->name();
name_ = name_uniquer->GetUniqueName(name_);
}

View File

@ -935,10 +935,6 @@ class HloInstruction {
// padding of this pad instruction. Only set for pad instructions.
std::unique_ptr<PaddingConfig> padding_config_;
// The computation that stores the instructions fused into this fusion
// instruction. Only set for fusion instructions.
std::unique_ptr<HloComputation> fused_instructions_computation_;
// If this instruction is fused into a fusion instruction, this field points
// to the fusion instruction.
HloInstruction* parent_fusion_instruction_ = nullptr;

View File

@ -557,78 +557,89 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
}
TEST_F(HloInstructionTest, SingletonFusionOp) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.
auto constant =
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
auto exp =
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp}, HloInstruction::FusionKind::kLoop);
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, exp.get());
EXPECT_THAT(fusion->operands(), ElementsAre(constant.get()));
EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get()));
EXPECT_THAT(fusion->operands(), ElementsAre(constant));
EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
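
The tests in this file migrate from hand-wiring HloInstruction::CreateFusion plus FuseInstruction on free-standing instructions to HloComputation::CreateFusionInstruction, which folds instructions already added to a module-owned computation. A minimal sketch of the new idiom (r0f32 stands in for the fixture's scalar shape r0f32_):

  HloComputation::Builder builder("sketch");
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, constant));
  HloModule module("sketch");
  auto* computation = module.AddEntryComputation(builder.Build());
  // Replaces 'exp' with a kFusion instruction; the fused computation is
  // registered on the module as an embedded fusion computation.
  auto* fusion = computation->CreateFusionInstruction(
      {exp}, HloInstruction::FusionKind::kLoop);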
TEST_F(HloInstructionTest, BinaryFusionOp) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single binary operation.
auto constant1 =
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
auto constant2 =
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f));
auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
constant1.get(), constant2.get());
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto constant2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{add}, HloInstruction::FusionKind::kLoop);
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, add.get());
EXPECT_THAT(fusion->operands(),
ElementsAre(constant1.get(), constant2.get()));
EXPECT_THAT(constant1->users(),
UnorderedElementsAre(fusion.get(), add.get()));
EXPECT_THAT(constant2->users(),
UnorderedElementsAre(fusion.get(), add.get()));
EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2));
EXPECT_THAT(constant1->users(), ElementsAre(fusion));
EXPECT_THAT(constant2->users(), ElementsAre(fusion));
}
TEST_F(HloInstructionTest, ChainFusionOp) {
HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
auto constant =
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
auto exp1 =
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
auto exp3 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, exp3.get());
fusion->FuseInstruction(exp2.get());
fusion->FuseInstruction(exp1.get());
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);
EXPECT_THAT(fusion->operands(), ElementsAre(constant.get()));
EXPECT_THAT(constant->users(),
UnorderedElementsAre(fusion.get(), exp1.get()));
EXPECT_THAT(fusion->operands(), ElementsAre(constant));
EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
auto constant =
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
auto exp1 =
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
OpMetadata metadata;
metadata.set_op_name("tf_op");
exp1->set_metadata(metadata);
exp2->set_metadata(metadata);
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
auto* fused = fusion->FuseInstruction(exp1.get());
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp2, exp1}, HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
EXPECT_TRUE(protobuf_util::ProtobufEquals(
metadata, fusion->fused_expression_root()->metadata()));
EXPECT_TRUE(protobuf_util::ProtobufEquals(
metadata, fusion->fused_expression_root()->operand(0)->metadata()));
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@ -642,33 +653,36 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
std::unique_ptr<HloComputation> computation_x = make_map_computation();
std::unique_ptr<HloComputation> computation_y = make_map_computation();
auto constant =
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
auto map_1_x =
HloInstruction::CreateMap(scalar_shape, {constant.get()},
computation_x.get(), /*static_operands=*/{});
auto map_2_x =
HloInstruction::CreateMap(scalar_shape, {map_1_x.get()},
computation_x.get(), /*static_operands=*/{});
auto map_3_y =
HloInstruction::CreateMap(scalar_shape, {map_2_x.get()},
computation_y.get(), /*static_operands=*/{});
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap(
scalar_shape, {constant}, computation_x.get(), /*static_operands=*/{}));
auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap(
scalar_shape, {map_1_x}, computation_x.get(), /*static_operands=*/{}));
auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
scalar_shape, {map_2_x}, computation_y.get(), /*static_operands=*/{}));
auto fusion = HloInstruction::CreateFusion(
scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get());
EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get()));
fusion->FuseInstruction(map_2_x.get());
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{map_3_y}, HloInstruction::FusionKind::kLoop);
auto* fused_computation = fusion->fused_instructions_computation();
EXPECT_THAT(fusion->called_computations(),
ElementsAre(computation_y.get(), computation_x.get()));
ElementsAre(fused_computation, computation_y.get()));
fusion->FuseInstruction(map_1_x.get());
EXPECT_THAT(fusion->called_computations(),
ElementsAre(computation_y.get(), computation_x.get()));
fusion->FuseInstruction(map_2_x);
EXPECT_THAT(
fusion->called_computations(),
ElementsAre(fused_computation, computation_y.get(), computation_x.get()));
fusion->FuseInstruction(map_1_x);
EXPECT_THAT(
fusion->called_computations(),
ElementsAre(fused_computation, computation_y.get(), computation_x.get()));
}
TEST_F(HloInstructionTest, ComplexFusionOp) {
HloComputation::Builder builder(TestName());
// Fuse all instructions in complicated expression:
//
// add = Add(C1, C2)
@ -680,35 +694,35 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
//
// Notable complexities are repeated operands in the same instruction, different
// shapes, use of value in different expressions.
auto c1 = HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
auto c2 = HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f));
auto c3 = HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f));
auto c1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto c2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
auto c3 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)));
auto add =
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get());
auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp,
c2.get(), add.get(), add.get());
auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get());
auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply,
exp.get(), c3.get());
auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract,
mul.get(), clamp.get());
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
auto clamp = builder.AddInstruction(
HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
auto mul = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
auto sub = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
auto tuple =
HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()});
builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, tuple.get());
fusion->FuseInstruction(sub.get());
fusion->FuseInstruction(mul.get());
fusion->FuseInstruction(exp.get());
fusion->FuseInstruction(clamp.get());
fusion->FuseInstruction(add.get());
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
// Operands in the fusion instruction's operands() vector should be in the
// order in which their users were fused.
EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get()));
EXPECT_THAT(c1->users(),
UnorderedElementsAre(add.get(), tuple.get(), fusion.get()));
EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
// Convenience function for comparing two HloInstructions inside of
@ -864,7 +878,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
HloInstruction* max = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
auto computation = builder.Build();
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
@ -906,7 +921,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, min, broadcast));
auto computation = builder.Build();
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
@ -945,7 +961,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape));
auto computation = builder.Build();
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);

View File

@ -183,6 +183,9 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
// ordering based on dependencies. ExecutesBefore will return true iff there
// exists a path in the HLO computation graph from 'a' to 'b'.
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
predecessors_.emplace(computation.get(),
computation->ComputeReachability());
}

View File

@ -1202,6 +1202,9 @@ StatusOr<bool> HloRematerialization::Run(
// After DCE, the module sequence may include instructions which no longer
// exist.
for (const auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
if (sequence->at(computation.get()).size() !=
computation->instruction_count()) {
// A size mismatch between the computation instruction count and the size

View File

@ -400,6 +400,9 @@ CreateMemoryMinimizingSequence(
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
for (const auto& computation : module.computations()) {
if (computation->IsFusionComputation()) {
continue;
}
TF_ASSIGN_OR_RETURN(sequence[computation.get()],
CreateMemoryMinimizingSequence(
*computation, *points_to_analysis, size_function));
@ -410,6 +413,7 @@ CreateMemoryMinimizingSequence(
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
return CreateMemoryMinimizingSequence(computation, *points_to_analysis,

View File

@ -211,8 +211,17 @@ bool InstructionFusion::CanFuseOnAllPaths(
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
bool changed = false;
std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
computation_ = computation.get();
if (computation->IsFusionComputation()) {
continue;
}
computations.push_back(computation.get());
}
for (auto& computation : computations) {
CHECK(!computation->IsFusionComputation());
computation_ = computation;
// We want to be able to remove arbitrary instructions from the post order
// and also compare positions of instructions in the post order. To make

View File

@ -611,6 +611,9 @@ Status CheckLayouts(
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
for (auto& instruction : computation->instructions()) {
// Verify every instruction has a layout and the layout is valid for the
// shape.
@ -1356,6 +1359,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_,
module->entry_computation()));
} else if (computation->IsFusionComputation()) {
continue;
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());
// Setting all embedded computations to the default layout is potentially

View File

@ -29,7 +29,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
return root;
} else {
tensorflow::strings::StrAppend(&root, separator_, *count);
// Increment the count under the old 'root' name.
(*count)++;
// Initialize count under new 'root' name.
count = &(generated_names_[root]);
*count = 1;
return root;
}
}
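
The effect of the extra bookkeeping, sketched as assumed behavior (the constructor/separator details are illustrative, not verified against the header): once a generated name like "foo.1" is handed out, it is registered with its own count, so a later request for the literal prefix "foo.1" cannot produce a duplicate.

  NameUniquer uniquer(/*separator=*/".");  // separator argument assumed
  uniquer.GetUniqueName("foo");    // -> "foo"
  uniquer.GetUniqueName("foo");    // -> "foo.1"; bumps the count under
                                   //    "foo" and seeds a count under
                                   //    "foo.1"
  uniquer.GetUniqueName("foo.1");  // -> "foo.1.1" instead of colliding
                                   //    with the earlier "foo.1"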

View File

@ -26,6 +26,9 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
std::vector<HloInstruction*> instructions_to_suffix;
for (auto& instruction : computation->instructions()) {

View File

@ -312,10 +312,17 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
StatusOr<bool> ReshapeMover::Run(HloModule* module) {
bool changed = false;
for (const auto& comp : module->computations()) {
std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
computations.push_back(computation.get());
}
for (const auto& comp : computations) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool did_change,
TrySinkReshapeOrTranspose(comp.get(), instruction));
TrySinkReshapeOrTranspose(comp, instruction));
changed |= did_change;
}
}

View File

@ -351,16 +351,15 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, reshape1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto fusion = computation->AddInstruction(HloInstruction::CreateFusion(
add->shape(), HloInstruction::FusionKind::kLoop, add));
TF_CHECK_OK(computation->ReplaceInstruction(add, fusion));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
computation->CreateFusionInstruction({add},
HloInstruction::FusionKind::kLoop);
EXPECT_THAT(computation->root_instruction(),
op::Fusion(op::Reshape(param0), op::Reshape(param1)));
EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());
EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Fusion(param0, param1)));

View File

@ -172,7 +172,14 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
return tensorflow::Status::OK();
};
for (auto& comp : module->computations()) {
std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
computations.push_back(computation.get());
}
for (auto& comp : computations) {
TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
}

View File

@ -135,6 +135,9 @@ TuplePointsToAnalysis::Run(const HloModule* module) {
Status TuplePointsToAnalysis::Analyze() {
points_to_.clear();
for (auto& computation : module_->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
TF_RETURN_IF_ERROR(computation->Accept(this));
TF_RETURN_IF_ERROR(
PopulateDefinedBuffersAndAliases(computation->instructions()));
@ -451,6 +454,9 @@ string TuplePointsToAnalysis::ToString() const {
string output = tensorflow::strings::Printf(
"TuplePointsToSet for module %s:\n", module_->name().c_str());
for (const auto& computation : module_->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
const char* entry =
computation.get() == module_->entry_computation() ? "entry " : "";
tensorflow::strings::StrAppend(&output, entry, "computation ",

View File

@ -35,23 +35,20 @@ py_library(
cuda_py_test(
name = "csiszar_divergence_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/csiszar_divergence_test.py"],
additional_deps = [
":bayesflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
@ -84,12 +81,11 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)

View File

@ -20,22 +20,24 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy_lib
from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import kullback_leibler as kullback_leibler_lib
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.platform import test
distributions = distributions_lib
layers = layers_lib
entropy = entropy_lib
class NormalNoEntropy(distributions.Normal): # pylint: disable=no-init
class NormalNoEntropy(normal_lib.Normal): # pylint: disable=no-init
"""Normal distribution without a `.entropy` method."""
def entropy(self):
@ -81,10 +83,10 @@ class ElboRatioTest(test.TestCase):
n_samples = 5000
with self.test_session():
q = distributions.MultivariateNormalDiag(
q = mvn_diag_lib.MultivariateNormalDiag(
loc=self._rng.rand(*vector_shape),
scale_diag=self._rng.rand(*vector_shape))
p = distributions.MultivariateNormalDiag(
p = mvn_diag_lib.MultivariateNormalDiag(
loc=self._rng.rand(*vector_shape),
scale_diag=self._rng.rand(*vector_shape))
@ -95,7 +97,7 @@ class ElboRatioTest(test.TestCase):
n=n_samples,
form=entropy.ELBOForms.sample,
seed=42)
actual_kl = distributions.kl_divergence(q, p)
actual_kl = kullback_leibler_lib.kl_divergence(q, p)
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
@ -109,10 +111,10 @@ class ElboRatioTest(test.TestCase):
vector_shape = (2, 3)
with self.test_session():
q = distributions.MultivariateNormalDiag(
q = mvn_diag_lib.MultivariateNormalDiag(
loc=self._rng.rand(*vector_shape),
scale_diag=self._rng.rand(*vector_shape))
p = distributions.MultivariateNormalDiag(
p = mvn_diag_lib.MultivariateNormalDiag(
loc=self._rng.rand(*vector_shape),
scale_diag=self._rng.rand(*vector_shape))
@ -123,7 +125,7 @@ class ElboRatioTest(test.TestCase):
n=n_samples,
form=entropy.ELBOForms.analytic_entropy,
seed=42)
actual_kl = distributions.kl_divergence(q, p)
actual_kl = kullback_leibler_lib.kl_divergence(q, p)
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
@ -135,7 +137,7 @@ class ElboRatioTest(test.TestCase):
vector_shape = (2, 3)
with self.test_session():
q = distributions.MultivariateNormalDiag(
q = mvn_diag_lib.MultivariateNormalDiag(
loc=self._rng.rand(*vector_shape),
scale_diag=self._rng.rand(*vector_shape))
@ -155,7 +157,7 @@ class EntropyShannonTest(test.TestCase):
def test_normal_entropy_default_form_uses_exact_entropy(self):
with self.test_session():
dist = distributions.Normal(loc=1.11, scale=2.22)
dist = normal_lib.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(dist, n=11)
exact_entropy = dist.entropy()
self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape())
@ -163,7 +165,7 @@ class EntropyShannonTest(test.TestCase):
def test_normal_entropy_analytic_form_uses_exact_entropy(self):
with self.test_session():
dist = distributions.Normal(loc=1.11, scale=2.22)
dist = normal_lib.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(
dist, form=entropy.ELBOForms.analytic_entropy)
exact_entropy = dist.entropy()
@ -173,7 +175,7 @@ class EntropyShannonTest(test.TestCase):
def test_normal_entropy_sample_form_gets_approximate_answer(self):
# Tested by showing we get a good answer that is not exact.
with self.test_session():
dist = distributions.Normal(loc=1.11, scale=2.22)
dist = normal_lib.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(
dist, n=1000, form=entropy.ELBOForms.sample, seed=0)
exact_entropy = dist.entropy()
@ -193,7 +195,7 @@ class EntropyShannonTest(test.TestCase):
# NormalNoEntropy is like a Normal, but does not have .entropy method, so
# we are forced to fall back on sample entropy.
dist_no_entropy = NormalNoEntropy(loc=1.11, scale=2.22)
dist_yes_entropy = distributions.Normal(loc=1.11, scale=2.22)
dist_yes_entropy = normal_lib.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(
dist_no_entropy, n=1000, form=entropy.ELBOForms.sample, seed=0)
@ -222,15 +224,16 @@ class RenyiRatioTest(test.TestCase):
mu_true = np.array([1.0, -1.0], dtype=np.float64)
chol_true = np.array([[2.0, 0.0], [0.5, 1.0]], dtype=np.float64)
with self.test_session() as sess:
target = distributions.MultivariateNormalTriL(mu_true, chol_true)
target = mvn_tril_lib.MultivariateNormalTriL(mu_true, chol_true)
# Set up q distribution by defining mean/covariance as Variables
mu = variables.Variable(
np.zeros(mu_true.shape), dtype=mu_true.dtype, name='mu')
mat = variables.Variable(
np.zeros(chol_true.shape), dtype=chol_true.dtype, name='mat')
chol = distributions.matrix_diag_transform(mat, transform=nn_ops.softplus)
q = distributions.MultivariateNormalTriL(mu, chol)
chol = distribution_util.matrix_diag_transform(
mat, transform=nn_ops.softplus)
q = mvn_tril_lib.MultivariateNormalTriL(mu, chol)
for alpha in [0.25, 0.75]:
negative_renyi_divergence = entropy.renyi_ratio(
@ -262,7 +265,7 @@ class RenyiRatioTest(test.TestCase):
n = 1000
vector_shape = (2, 3)
with self.test_session():
q = distributions.MultivariateNormalDiag(
q = mvn_diag_lib.MultivariateNormalDiag(
loc=self._rng.rand(*vector_shape),
scale_diag=self._rng.rand(*vector_shape))
for alpha in [0.25, 0.75]:

View File

@ -36,27 +36,48 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":bijectors_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
cuda_py_test(
name = "estimator_test",
size = "small",
srcs = ["python/kernel_tests/estimator_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/learn:head_test",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:session",
],
)
cuda_py_test(
name = "distribution_test",
size = "small",

View File

@ -30,6 +30,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_transformed_distrib
from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
from tensorflow.contrib.distributions.python.ops.estimator import *
from tensorflow.contrib.distributions.python.ops.geometric import *
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
from tensorflow.contrib.distributions.python.ops.logistic import *
@ -147,6 +148,7 @@ _allowed_symbols = [
'percentile',
'assign_exponential_moving_mean_variance',
'exponential_moving_mean_variance',
'estimator_head_distribution_regression',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,114 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for estimator.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib.distributions.python.ops import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators.head_test import _assert_metrics
from tensorflow.contrib.learn.python.learn.estimators.head_test import _assert_no_variables
from tensorflow.contrib.learn.python.learn.estimators.head_test import _assert_summary_tags
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
class EstimatorHeadDistributionRegressionTest(test.TestCase):
def _assert_output_alternatives(self, model_fn_ops):
self.assertEquals({
None: constants.ProblemType.LINEAR_REGRESSION
}, {
k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
})
def testNormalLocScaleLogits(self):
# We will bias logits[..., 1] so that: logits[..., 1]=0 implies scale=1.
scale_bias = np.log(np.expm1(1.))
def softplus(x):
return np.log1p(np.exp(x))
def actual_loss(logits, labels):
mu = actual_mean(logits)
sigma = actual_stddev(logits)
labels = np.squeeze(labels, -1)
z = (labels - mu) / sigma
loss = 0.5 * (z**2. + np.log(2. * np.pi)) + np.log(sigma)
return loss.mean()
def actual_mean(logits):
return logits[..., 0]
def actual_stddev(logits):
return softplus(logits[..., 1] + scale_bias)
def make_distribution_fn(logits):
return normal_lib.Normal(
loc=logits[..., 0],
scale=nn_ops.softplus(logits[..., 1] + scale_bias))
head = estimator_lib.estimator_head_distribution_regression(
make_distribution_fn,
logits_dimension=2)
labels = np.float32([[-1.],
[0.],
[1.]])
logits = np.float32([[0., -1],
[1, 0.5],
[-1, 1]])
with ops.Graph().as_default(), session.Session():
# Convert to tensor so we can index into head.distributions.
tflogits = ops.convert_to_tensor(logits, name="logits")
model_fn_ops = head.create_model_fn_ops(
{},
labels=labels,
mode=model_fn.ModeKeys.TRAIN,
train_op_fn=head_lib.no_op_train_fn,
logits=tflogits)
self._assert_output_alternatives(model_fn_ops)
_assert_summary_tags(self, ["loss"])
_assert_no_variables(self)
loss = actual_loss(logits, labels)
_assert_metrics(self, loss, {"loss": loss}, model_fn_ops)
# Now we verify the underlying distribution was correctly constructed.
expected_mean = logits[..., 0]
self.assertAllClose(
expected_mean,
head.distribution(tflogits).mean().eval(),
rtol=1e-6, atol=0.)
expected_stddev = softplus(logits[..., 1] + scale_bias)
self.assertAllClose(
expected_stddev,
head.distribution(tflogits).stddev().eval(),
rtol=1e-6, atol=0.)
# Should have created only one distribution.
self.assertEqual(1, len(head.distributions))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,185 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to bridge `Distribution`s and `tf.contrib.learn.estimator` APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators.head import _compute_weighted_loss
from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHead
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
__all__ = [
"estimator_head_distribution_regression",
]
def estimator_head_distribution_regression(make_distribution_fn,
label_dimension=1,
logits_dimension=None,
label_name=None,
weight_column_name=None,
enable_centered_bias=False,
head_name=None):
"""Creates a `Head` for regression under a generic distribution.
Args:
make_distribution_fn: Python `callable` which returns a `tf.Distribution`
instance created using only logits.
label_dimension: Number of regression labels per example. This is the size
of the last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`).
logits_dimension: Number of logits per example. This is the size of the last
dimension of the logits `Tensor` (typically, this has shape
`[batch_size, logits_dimension]`).
Default value: `label_dimension`.
label_name: Python `str`, name of the key in label `dict`. Can be `None` if
label is a `Tensor` (single headed models).
weight_column_name: Python `str` defining feature column name representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
enable_centered_bias: Python `bool`. If `True`, estimator will learn a
centered bias variable for each class. Rest of the model structure learns
the residual after centered bias.
head_name: Python `str`, name of the head. Predictions, summary and metrics
keys are suffixed by `"/" + head_name` and the default variable scope is
`head_name`.
Returns:
An instance of `Head` for generic regression.
"""
return _DistributionRegressionHead(
make_distribution_fn=make_distribution_fn,
label_dimension=label_dimension,
logits_dimension=logits_dimension,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
head_name=head_name)
class _DistributionRegressionHead(_RegressionHead):
"""Creates a _RegressionHead instance from an arbitray `Distribution`."""
def __init__(self,
make_distribution_fn,
label_dimension,
logits_dimension=None,
label_name=None,
weight_column_name=None,
enable_centered_bias=False,
head_name=None):
"""`Head` for regression.
Args:
make_distribution_fn: Python `callable` which returns a `tf.Distribution`
instance created using only logits.
label_dimension: Number of regression labels per example. This is the
size of the last dimension of the labels `Tensor` (typically, this has
shape `[batch_size, label_dimension]`).
logits_dimension: Number of logits per example. This is the size of the
last dimension of the logits `Tensor` (typically, this has shape
`[batch_size, logits_dimension]`).
Default value: `label_dimension`.
label_name: Python `str`, name of the key in label `dict`. Can be `None`
if label is a tensor (single headed models).
weight_column_name: Python `str` defining feature column name representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
enable_centered_bias: Python `bool`. If `True`, estimator will learn a
centered bias variable for each class. Rest of the model structure
learns the residual after centered bias.
head_name: Python `str`, name of the head. Predictions, summary and
metrics keys are suffixed by `"/" + head_name` and the default variable
scope is `head_name`.
Raises:
TypeError: if `make_distribution_fn` is not `callable`.
"""
if not callable(make_distribution_fn):
raise TypeError("`make_distribution_fn` must be a callable function.")
self._distributions = {}
self._make_distribution_fn = make_distribution_fn
def static_value(x):
"""Returns the static value of a `Tensor` or `None`."""
return tensor_util.constant_value(ops.convert_to_tensor(x))
def concat_vectors(*args):
"""Concatenates input vectors, statically if possible."""
args_ = [static_value(x) for x in args]
if any(vec is None for vec in args_):
return array_ops.concat(args, axis=0)
return [val for vec in args_ for val in vec]
def loss_fn(labels, logits, weights=None):
"""Returns the loss of using `logits` to predict `labels`."""
d = self.distribution(logits)
labels_batch_shape = labels.shape.with_rank_at_least(1)[:-1]
labels_batch_shape = (
labels_batch_shape.as_list() if labels_batch_shape.is_fully_defined()
else array_ops.shape(labels)[:-1])
labels = array_ops.reshape(
labels,
shape=concat_vectors(labels_batch_shape, d.event_shape_tensor()))
return _compute_weighted_loss(
loss_unweighted=-d.log_prob(labels),
weight=weights)
def link_fn(logits):
"""Returns the inverse link function at `logits`."""
# Note: What the API calls a "link function" is really the inverse-link
# function, i.e., the "mean".
d = self.distribution(logits)
return d.mean()
super(_DistributionRegressionHead, self).__init__(
label_dimension=label_dimension,
loss_fn=loss_fn,
link_fn=link_fn,
logits_dimension=logits_dimension,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
head_name=head_name)
@property
def distributions(self):
"""Returns all distributions created by `DistributionRegressionHead`."""
return self._distributions
def distribution(self, logits, name=None):
"""Retrieves a distribution instance, parameterized by `logits`.
Args:
logits: `float`-like `Tensor` representing the parameters of the
underlying distribution.
name: The Python `str` name to give to this op.
Default value: "distribution".
Returns:
distribution: `tf.Distribution` instance parameterized by `logits`.
"""
with ops.name_scope(name, "distribution", [logits]):
d = self._distributions.get(logits, None)
if d is None:
d = self._make_distribution_fn(logits)
self._distributions[logits] = d
return d
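A minimal usage sketch for this head (an illustration, not part of the file above; the Laplace distribution and the import path of this module are assumptions based on the contrib layout):
# Hypothetical example: regression under a Laplace distribution.
from tensorflow.contrib.distributions.python.ops import estimator as estimator_lib  # path assumed
from tensorflow.contrib.distributions.python.ops import laplace as laplace_lib
from tensorflow.python.ops import nn_ops
head = estimator_lib.estimator_head_distribution_regression(
    lambda logits: laplace_lib.Laplace(
        loc=logits[..., 0],
        scale=nn_ops.softplus(logits[..., 1])),
    logits_dimension=2)
Note that the head caches one distribution instance per logits tensor, so repeated calls to head.distribution(logits) reuse the same object.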

View File

@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
@ -189,7 +189,7 @@ class MultivariateNormalLinearOperator(
distribution=normal.Normal(
loc=array_ops.zeros([], dtype=scale.dtype),
scale=array_ops.ones([], dtype=scale.dtype)),
bijector=bijectors.AffineLinearOperator(
bijector=AffineLinearOperator(
shift=loc, scale=scale, validate_args=validate_args),
batch_shape=batch_shape,
event_shape=event_shape,

View File

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -31,13 +32,16 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import convolutional as convolutional_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.layers import pooling as pooling_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
@ -1281,7 +1285,7 @@ def convolution3d_transpose(
trainable=True,
scope=None):
"""Adds a convolution3d_transpose with an optional batch normalization layer.
The function creates a variable called `weights`, representing the
kernel, that is convolved with the input. If `batch_norm_params` is `None`, a
second variable called 'biases' is added to the result of the operation.
@ -1808,6 +1812,300 @@ def layer_norm(inputs,
outputs)
class GDN(base.Layer):
"""Generalized divisive normalization layer.
Based on the papers:
"Density Modeling of Images using a Generalized Normalization
Transformation"
Johannes Ballé, Valero Laparra, Eero P. Simoncelli
https://arxiv.org/abs/1511.06281
"End-to-end Optimized Image Compression"
Johannes Ballé, Valero Laparra, Eero P. Simoncelli
https://arxiv.org/abs/1611.01704
Implements an activation function that is essentially a multivariate
generalization of a particular sigmoid-type function:
y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
where i and j run over channels. This implementation never sums across spatial
dimensions. It is similar to local response normalization, but more powerful,
as beta and gamma are trainable parameters.
Arguments:
inverse: If False (default), compute GDN response. If True, compute IGDN
response (one step of fixed point iteration to invert GDN; the division
is replaced by multiplication).
beta_min: Lower bound for beta, to prevent numerical error from causing
square root of zero or negative values.
gamma_init: The gamma matrix will be initialized as the identity matrix
multiplied with this value. If set to zero, the layer is effectively
initialized to the identity operation, since beta is initialized as one.
A good default setting is somewhere between 0 and 0.5.
reparam_offset: Offset added to the reparameterization of beta and gamma.
The reparameterization of beta and gamma as their square roots lets the
training slow down when their values are close to zero, which is desirable
as small values in the denominator can lead to a situation where gradient
noise on beta/gamma leads to extreme amounts of noise in the GDN
activations. However, without the offset, we would get zero gradients if
any elements of beta or gamma were exactly zero, and thus the training
could get stuck. To prevent this, we add this small constant. The default
value was empirically determined as a good starting point. Making it
bigger potentially leads to more gradient noise on the activations, while
making it too small may lead to numerical precision issues.
data_format: Format of input tensor. Currently supports 'channels_first' and
'channels_last'.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such cases.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
Properties:
inverse: Boolean, whether IGDN is computed (True) or GDN (False).
data_format: Format of input tensor. Currently supports 'channels_first' and
'channels_last'.
beta: The beta parameter as defined above (1D TensorFlow tensor).
gamma: The gamma parameter as defined above (2D TensorFlow tensor).
"""
def __init__(self,
inverse=False,
beta_min=1e-6,
gamma_init=.1,
reparam_offset=2 ** -18,
data_format='channels_last',
trainable=True,
name=None,
**kwargs):
super(GDN, self).__init__(trainable=trainable, name=name, **kwargs)
self.inverse = inverse
self._beta_min = beta_min
self._gamma_init = gamma_init
self._reparam_offset = reparam_offset
self.data_format = data_format
self._channel_axis() # trigger ValueError early
self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5)
def _channel_axis(self):
try:
return {'channels_first': 1, 'channels_last': -1}[self.data_format]
except KeyError:
raise ValueError('Unsupported `data_format` for GDN layer: {}.'.format(
self.data_format))
@staticmethod
def _lower_bound(inputs, bound, name=None):
"""Same as tf.maximum, but with helpful gradient for inputs < bound.
The gradient is overwritten so that it is passed through if the input is not
hitting the bound. If it is, only gradients that push `inputs` higher than
the bound are passed through. No gradients are passed through to the bound.
Args:
inputs: input tensor
bound: lower bound for the input tensor
name: name for this op
Returns:
tf.maximum(inputs, bound)
"""
with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
inputs = ops.convert_to_tensor(inputs, name='inputs')
bound = ops.convert_to_tensor(bound, name='bound')
with ops.get_default_graph().gradient_override_map(
{'Maximum': 'GDNLowerBound'}):
return math_ops.maximum(inputs, bound, name=scope)
@ops.RegisterGradient('GDNLowerBound')
@staticmethod
def _lower_bound_grad(op, grad):
"""Gradient for `_lower_bound`.
Args:
op: the tensorflow op for which to calculate a gradient
grad: gradient with respect to the output of the op
Returns:
gradients with respect to the inputs of the op
"""
inputs = op.inputs[0]
bound = op.inputs[1]
pass_through_if = math_ops.logical_or(inputs >= bound, grad < 0)
return [math_ops.cast(pass_through_if, grad.dtype) * grad, None]
def build(self, input_shape):
channel_axis = self._channel_axis()
input_shape = tensor_shape.TensorShape(input_shape)
num_channels = input_shape[channel_axis].value
if num_channels is None:
raise ValueError('The channel dimension of the inputs to `GDN` '
'must be defined.')
self._input_rank = input_shape.ndims
self.input_spec = base.InputSpec(ndim=input_shape.ndims,
axes={channel_axis: num_channels})
pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
beta_bound = array_ops.constant(
(self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)
def beta_initializer(shape, dtype=None, partition_info=None):
del partition_info # unused
return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal)
def gamma_initializer(shape, dtype=None, partition_info=None):
del partition_info # unused
assert len(shape) == 2
assert shape[0] == shape[1]
eye = linalg_ops.eye(shape[0], dtype=dtype)
return math_ops.sqrt(self._gamma_init * eye + pedestal)
beta = self.add_variable('reparam_beta',
shape=[num_channels],
initializer=beta_initializer,
dtype=self.dtype,
trainable=True)
beta = self._lower_bound(beta, beta_bound)
self.beta = math_ops.square(beta) - pedestal
gamma = self.add_variable('reparam_gamma',
shape=[num_channels, num_channels],
initializer=gamma_initializer,
dtype=self.dtype,
trainable=True)
gamma = self._lower_bound(gamma, gamma_bound)
self.gamma = math_ops.square(gamma) - pedestal
self.built = True
def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
ndim = self._input_rank
shape = self.gamma.get_shape().as_list()
gamma = array_ops.reshape(self.gamma, (ndim - 2) * [1] + shape)
# Compute normalization pool.
if self.data_format == 'channels_first':
norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
data_format='NC' + 'DHW'[-(ndim - 2):])
if ndim == 3:
norm_pool = array_ops.expand_dims(norm_pool, 2)
norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
norm_pool = array_ops.squeeze(norm_pool, [2])
elif ndim == 5:
shape = array_ops.shape(norm_pool)
# Concatenate shapes rather than adding: `shape[:3] + [-1]` on a
# shape `Tensor` would broadcast-add, yielding a wrong target shape.
norm_pool = array_ops.reshape(
    norm_pool, array_ops.concat([shape[:3], [-1]], axis=0))
norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
norm_pool = array_ops.reshape(norm_pool, shape)
else: # ndim == 4
norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
else: # channels_last
norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID')
norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NHWC')
norm_pool = math_ops.sqrt(norm_pool)
if self.inverse:
outputs = inputs * norm_pool
else:
outputs = inputs / norm_pool
outputs.set_shape(inputs.get_shape())
return outputs
def _compute_output_shape(self, input_shape):
channel_axis = self._channel_axis()
input_shape = tensor_shape.TensorShape(input_shape)
if not 3 <= input_shape.ndims <= 5:
raise ValueError('`input_shape` must be of rank 3 to 5, inclusive.')
if input_shape[channel_axis].value is None:
raise ValueError(
'The channel dimension of `input_shape` must be defined.')
return input_shape
def gdn(inputs,
inverse=False,
beta_min=1e-6,
gamma_init=.1,
reparam_offset=2 ** -18,
data_format='channels_last',
trainable=True,
name=None,
reuse=None):
"""Functional interface for GDN layer.
Based on the papers:
"Density Modeling of Images using a Generalized Normalization
Transformation"
Johannes Ballé, Valero Laparra, Eero P. Simoncelli
https://arxiv.org/abs/1511.06281
"End-to-end Optimized Image Compression"
Johannes Ballé, Valero Laparra, Eero P. Simoncelli
https://arxiv.org/abs/1611.01704
Implements an activation function that is essentially a multivariate
generalization of a particular sigmoid-type function:
y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
where i and j run over channels. This implementation never sums across spatial
dimensions. It is similar to local response normalization, but more powerful,
as beta and gamma are trainable parameters.
Arguments:
inputs: Tensor input.
inverse: If False (default), compute GDN response. If True, compute IGDN
response (one step of fixed point iteration to invert GDN; the division
is replaced by multiplication).
beta_min: Lower bound for beta, to prevent numerical error from causing
square root of zero or negative values.
gamma_init: The gamma matrix will be initialized as the identity matrix
multiplied with this value. If set to zero, the layer is effectively
initialized to the identity operation, since beta is initialized as one.
A good default setting is somewhere between 0 and 0.5.
reparam_offset: Offset added to the reparameterization of beta and gamma.
The reparameterization of beta and gamma as their square roots lets the
training slow down when their values are close to zero, which is desirable
as small values in the denominator can lead to a situation where gradient
noise on beta/gamma leads to extreme amounts of noise in the GDN
activations. However, without the offset, we would get zero gradients if
any elements of beta or gamma were exactly zero, and thus the training
could get stuck. To prevent this, we add this small constant. The default
value was empirically determined as a good starting point. Making it
bigger potentially leads to more gradient noise on the activations, while
making it too small may lead to numerical precision issues.
data_format: Format of input tensor. Currently supports 'channels_first' and
'channels_last'.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such cases.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
Returns:
Output tensor.
"""
layer = GDN(inverse=inverse,
beta_min=beta_min,
gamma_init=gamma_init,
reparam_offset=reparam_offset,
data_format=data_format,
trainable=trainable,
name=name,
dtype=inputs.dtype.base_dtype,
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
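A quick numerical sanity check of the formula above (a sketch, not part of this file): with the defaults, gamma is initialized to 0.1 times the identity and beta to one, so GDN reduces elementwise to x / sqrt(1 + 0.1 * x**2), which is exactly what the GDN tests later in this diff assert.
import numpy as np
x = np.random.uniform(size=(1, 2, 3, 4))
# With gamma = 0.1 * I, the channel sum collapses to 0.1 * x[i]**2.
y = x / np.sqrt(1. + .1 * x ** 2)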
@add_arg_scope
def max_pool2d(inputs,
kernel_size,

View File

@ -2772,6 +2772,56 @@ class LayerNormTest(test.TestCase):
self.doOutputTest((1, 100, 100, 1))
class GDNTest(test.TestCase):
def _runGDN(self, x, shape, inverse, data_format):
inputs = array_ops.placeholder(dtypes.float32, shape)
outputs = _layers.gdn(inputs, inverse=inverse, data_format=data_format)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
y, = sess.run([outputs], {inputs: x})
return y
def testInvalidDataFormat(self):
x = np.random.uniform(size=(1, 2, 3, 4))
with self.assertRaises(ValueError):
self._runGDN(x, x.shape, False, 'NHWC')
def testUnknownDim(self):
x = np.random.uniform(size=(1, 2, 3, 4))
with self.assertRaises(ValueError):
self._runGDN(x, 4 * [None], False, 'channels_last')
def testChannelsLast(self):
for ndim in [3, 4, 5]:
x = np.random.uniform(size=(1, 2, 3, 4)[:ndim])
y = self._runGDN(x, x.shape, False, 'channels_last')
self.assertEqual(x.shape, y.shape)
self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
def testChannelsFirst(self):
# `bias_add` doesn't support NCHW on CPU.
if test.is_gpu_available(cuda_only=True):
for ndim in [3, 4, 5]:
x = np.random.uniform(size=(4, 3, 2, 1)[:ndim])
y = self._runGDN(x, x.shape, False, 'channels_first')
self.assertEqual(x.shape, y.shape)
self.assertAllClose(
y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
def testWrongDims(self):
for ndim in [1, 2, 6]:
x = np.random.uniform(size=(1, 2, 3, 4, 3, 2)[:ndim])
with self.assertRaises(ValueError):
self._runGDN(x, x.shape, False, 'channels_last')
def testIGDN(self):
x = np.random.uniform(size=(1, 2, 3, 4))
y = self._runGDN(x, x.shape, True, 'channels_last')
self.assertEqual(x.shape, y.shape)
self.assertAllClose(y, x * np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
class MaxPool2DTest(test.TestCase):
def testInvalidDataFormat(self):

View File

@ -665,6 +665,7 @@ class _RegressionHead(_SingleHead):
label_dimension,
loss_fn,
link_fn,
logits_dimension=None,
label_name=None,
weight_column_name=None,
enable_centered_bias=False,
@ -677,6 +678,10 @@ class _RegressionHead(_SingleHead):
shape `[batch_size, label_dimension]`).
loss_fn: Loss function, takes logits and labels and returns loss.
link_fn: Link function, takes a logits tensor and returns the output.
logits_dimension: Number of logits per example. This is the
size of the last dimension of the logits `Tensor` (typically, this has
shape `[batch_size, logits_dimension]`).
Default value: `label_dimension`.
label_name: String, name of the key in label dict. Can be null if label
is a tensor (single headed models).
weight_column_name: A string defining feature column name representing
@ -691,7 +696,8 @@ class _RegressionHead(_SingleHead):
"""
super(_RegressionHead, self).__init__(
problem_type=constants.ProblemType.LINEAR_REGRESSION,
logits_dimension=label_dimension,
logits_dimension=(logits_dimension if logits_dimension is not None
else label_dimension),
label_name=label_name,
weight_column_name=weight_column_name,
head_name=head_name)

View File

@ -103,7 +103,7 @@ def stft(signals, frame_length, frame_step, fft_length=None,
def inverse_stft(stfts,
frame_length,
frame_step,
fft_length,
fft_length=None,
window_fn=functools.partial(window_ops.hann_window,
periodic=True),
name=None):
@ -118,7 +118,8 @@ def inverse_stft(stfts,
frame_length: An integer scalar `Tensor`. The window length in samples.
frame_step: An integer scalar `Tensor`. The number of samples to step.
fft_length: An integer scalar `Tensor`. The size of the FFT that produced
`stfts`.
`stfts`. If not provided, uses the smallest power of 2 enclosing
`frame_length`.
window_fn: A callable that takes a window length and a `dtype` keyword
argument and returns a `[window_length]` `Tensor` of samples in the
provided datatype. If set to `None`, no windowing is used.
@ -130,7 +131,8 @@ def inverse_stft(stfts,
Raises:
ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
`frame_step` is not scalar, or `fft_length` is not scalar.
`frame_step` is not scalar, or `fft_length` is not scalar, or
`frame_length` is greater than `fft_length`.
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
@ -141,8 +143,21 @@ def inverse_stft(stfts,
frame_length.shape.assert_has_rank(0)
frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
frame_step.shape.assert_has_rank(0)
fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
fft_length.shape.assert_has_rank(0)
if fft_length is None:
fft_length = _enclosing_power_of_two(frame_length)
else:
fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
fft_length.shape.assert_has_rank(0)
frame_length_static = tensor_util.constant_value(
frame_length)
fft_length_static = tensor_util.constant_value(fft_length)
if (frame_length_static is not None and fft_length_static is not None and
frame_length_static > fft_length_static):
raise ValueError('frame_length (%d) may not be larger than '
'fft_length (%d)' % (frame_length_static,
fft_length_static))
real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length]
# Optionally window and overlap-add the inner 2 dimensions of real_frames
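The `_enclosing_power_of_two` helper referenced above supplies the default `fft_length`. A sketch of the static case (the real helper must also handle `Tensor`-valued lengths; this reimplementation is an assumption based on the docstring):
def enclosing_power_of_two(frame_length):
  # Smallest power of 2 that is >= frame_length, e.g. 400 -> 512.
  return 1 << (int(frame_length) - 1).bit_length()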

View File

@ -98,14 +98,16 @@ class DirectSession : public Session {
::tensorflow::Status ListDevices(
std::vector<DeviceAttributes>* response) override;
::tensorflow::Status Close() override;
::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
*output = device_mgr_.get();
return ::tensorflow::Status::OK();
}
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
cost_model_manager_.ExportCostModels(cost_models);
}
private:
typedef DirectSession ME;
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -1248,6 +1249,16 @@ TEST(DirectSessionTest, TestDirectSessionReset) {
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
TEST(DirectSessionTest, LocalDeviceManager) {
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
const DeviceMgr* mgr = nullptr;
TF_ASSERT_OK(session->LocalDeviceManager(&mgr));
ASSERT_TRUE(mgr != nullptr);
EXPECT_GT(mgr->ListDevices().size(), 0);
}
// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
void FeedFetchBenchmarkHelper(int num_feeds, int iters) {

View File

@ -260,6 +260,8 @@ class CallOp : public AsyncOpKernel {
done);
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.rendezvous = ctx->rendezvous();
opts.cancellation_manager = ctx->cancellation_manager();
opts.step_container = ctx->step_container();
opts.stats_collector = ctx->stats_collector();
opts.runner = ctx->runner();
@ -545,23 +547,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
Executor::Args exec_args;
// Inherit the step_id from the caller.
exec_args.step_id = opts.step_id;
exec_args.step_container = opts.step_container;
exec_args.rendezvous = opts.rendezvous;
exec_args.stats_collector = opts.stats_collector;
exec_args.call_frame = frame;
exec_args.cancellation_manager = opts.cancellation_manager;
exec_args.step_container = opts.step_container;
exec_args.runner = *opts.runner;
// TODO(zhifengc): we can avoid creating rendez here if we know
// there is no send/recv nodes in the graph.
auto* rendez = new IntraProcessRendezvous(device_mgr_);
exec_args.rendezvous = rendez;
item->exec->RunAsync(
// Executor args
exec_args,
// Done callback.
[item, frame, rets, rendez, done](const Status& status) {
[item, frame, rets, done](const Status& status) {
item->Unref();
rendez->Unref();
Status s = status;
if (s.ok()) {
s = frame->GetRetvals(rets);

View File

@ -104,10 +104,21 @@ Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
});
}
static Node* Send(Graph* g, const string& device_name, bool host,
const Edge* edge) {
const string tensor_name =
strings::StrCat("edge_", edge->id(), "_", edge->src()->name());
// Given an Edge whose two endpoints have different memory types and
// which is about to get a pair of HostSend/Recv or Send/HostRecv nodes,
// GetTensorName() returns a unique string that we can use as part of
// the rendezvous key. The return string is guaranteed to be unique
// within this process. That is sufficient because EnsureMemoryTypes
// is only used on a TensorFlow graph that is going to be executed in
// a single tf device (hence within a single process).
static string GetTensorName(const Edge* edge) {
static std::atomic<int64> counter(0);
return strings::StrCat("memtype_", counter.fetch_add(1), "_",
edge->src()->name());
}
static Node* Send(Graph* g, const string& tensor_name,
const string& device_name, bool host, const Edge* edge) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
.Input(edge->src(), edge->src_output())
@ -115,14 +126,13 @@ static Node* Send(Graph* g, const string& device_name, bool host,
.Attr("send_device", device_name)
.Attr("send_device_incarnation", 0) // Do not care.
.Attr("recv_device", device_name)
.Attr("_hostmem_sendrecv", true)
.Finalize(g, &ret));
return ret;
}
static Node* Recv(Graph* g, const string& device_name, bool host,
const Edge* edge) {
const string tensor_name =
strings::StrCat("edge_", edge->id(), "_", edge->src()->name());
static Node* Recv(Graph* g, const string& tensor_name,
const string& device_name, bool host, const Edge* edge) {
Node* ret;
TF_CHECK_OK(
NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
@ -131,6 +141,7 @@ static Node* Recv(Graph* g, const string& device_name, bool host,
.Attr("send_device", device_name)
.Attr("send_device_incarnation", 0)
.Attr("recv_device", device_name)
.Attr("_hostmem_sendrecv", true)
.Finalize(g, &ret));
return ret;
}
@ -171,8 +182,10 @@ Status EnsureMemoryTypes(const DeviceType& device_type,
Endpoint key{e->src()->id(), e->src_output()};
auto iter = recv_nodes.find(key);
if (iter == recv_nodes.end()) {
Node* send = Send(g, device_name, (item.sm == HOST_MEMORY), e);
recv = Recv(g, device_name, (item.dm == HOST_MEMORY), e);
const string tensor_name = GetTensorName(e);
Node* send =
Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
if (!has_ref) {
// We only cache if there is no ref is involved.
recv_nodes[key] = recv;

View File

@ -37,6 +37,7 @@ class CancellationManager;
class GraphDef;
class OpKernel;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollector;
class Node;
@ -398,11 +399,10 @@ class FunctionLibraryRuntime {
//
// Does not take ownership of "rets".
struct Options {
CancellationManager* cancellation_manager = nullptr;
// The id of the step that is calling this function.
int64 step_id = 0;
// Per-step container.
Rendezvous* rendezvous = nullptr;
CancellationManager* cancellation_manager = nullptr;
ScopedStepContainer* step_container = nullptr;
StepStatsCollector* stats_collector = nullptr;

View File

@ -896,6 +896,36 @@ Status AddControlEdges(const PartitionOptions& opts,
return Status::OK();
}
// If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
// if possible.
void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
StringPiece op(ndef->op());
if (op != "_Send" && op != "_Recv") {
// Not related to send/recv.
return;
}
string send_device;
if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
// No known send_device. The runtime will detect it later.
return;
}
int64 incarnation = opts.get_incarnation(send_device);
AddNodeAttr("send_device_incarnation", incarnation, ndef);
}
// Sets attribute send_device_incarnation of all Send/Recv nodes in
// 'gdef', if possible.
void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
for (NodeDef& ndef : *gdef->mutable_node()) {
SetIncarnation(opts, &ndef);
}
for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
for (NodeDef& ndef : *fdef.mutable_node_def()) {
SetIncarnation(opts, &ndef);
}
}
}
Status Partition(const PartitionOptions& opts, Graph* g,
std::unordered_map<string, GraphDef>* partitions) {
Status status;
@ -1130,10 +1160,15 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
}
// Set versions and function library
// Set versions, function library and send/recv incarnation.
for (auto& it : *partitions) {
it.second.mutable_versions()->CopyFrom(g->versions());
*it.second.mutable_library() = g->flib_def().ToProto();
GraphDef* gdef = &it.second;
*gdef->mutable_versions() = g->versions();
*gdef->mutable_library() = g->flib_def().ToProto();
// Traverse the graph to fill every send/recv op's incarnation
// information.
SetIncarnation(opts, gdef);
}
// Set the start times for recvs at the very end.

View File

@ -114,7 +114,9 @@ cc_test(
deps = [
":single_machine",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -73,8 +73,6 @@ SingleMachine::~SingleMachine() {
// when we delete the session.
thread_pool_.reset();
Reset(options_, {}).IgnoreError();
CHECK(already_created);
already_created = false;
}
@ -277,11 +275,9 @@ Status SingleMachine::ResetSession() {
// Make sure the session is properly closed
TF_RETURN_IF_ERROR(Shutdown());
// We need to Reset the session to ensure that all the variables are
// deleted. But first we need to delete the session since Reset()
// deletes some of the containers referenced by the session.
// Destroying the object deletes all its variables as well. This is only true
// for DirectSession.
session_.reset();
TF_RETURN_IF_ERROR(Reset(options_, {}));
}
LOG(INFO) << "Starting new session";

View File

@ -15,7 +15,10 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@ -24,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
namespace tensorflow {
namespace grappler {
@ -349,6 +353,7 @@ TEST_F(SingleMachineTest, InitializationMemory) {
}
namespace {
template <class T>
inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
AttrValue attr_value;
@ -463,6 +468,124 @@ TEST_F(SingleMachineTest, PersistentMemory) {
EXPECT_TRUE(found_hashtable);
}
#if defined(PLATFORM_GOOGLE)
namespace {
SessionOptions GetSessionOption(int num_cpu_cores, int num_gpus) {
SessionOptions options;
// Copied from single_machine.h
(*options.config.mutable_device_count())["CPU"] = 1;
if (num_gpus > 0) {
(*options.config.mutable_device_count())["GPU"] = num_gpus;
}
CHECK_GE(num_cpu_cores, 1);
options.config.set_intra_op_parallelism_threads(num_cpu_cores);
options.config.add_session_inter_op_thread_pool()->set_num_threads(
num_cpu_cores);
return options;
}
Status GetDeviceMemoryStats(
const SessionOptions& session_option,
std::unordered_map<string, AllocatorStats>* allocator_stats_by_device) {
std::vector<Device*> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(session_option,
"" /* name_prefix */, &devices));
allocator_stats_by_device->clear();
for (Device* device : devices) {
AllocatorStats stats;
auto* allocator = device->GetAllocator(AllocatorAttributes());
if (!allocator->TracksAllocationSizes()) {
return Status(error::INVALID_ARGUMENT,
"Tracking allocation is not enabled.");
}
allocator->GetStats(&stats);
(*allocator_stats_by_device)[device->name()] = stats;
}
return Status::OK();
}
} // namespace
TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
// Add a variable and initializer.
Output a = ops::Variable(s.WithOpName("a"), TensorShape({128, 256}),
DataType::DT_FLOAT);
Output a_init =
ops::RandomNormal(s.WithOpName("a/init"), {128, 256}, DataType::DT_FLOAT);
Output a_init_assign = ops::Assign(s.WithOpName("a/init/assign"), a, a_init);
// Add a resource variable.
Output b =
ops::VarHandleOp(s.WithOpName("b"), DataType::DT_FLOAT, {256, 512});
Output b_read =
ops::ReadVariableOp(s.WithOpName("b/read"), b, DataType::DT_FLOAT);
Output b_init =
ops::RandomNormal(s.WithOpName("b/init"), {256, 512}, DataType::DT_FLOAT);
auto b_init_assign =
ops::AssignVariableOp(s.WithOpName("b/init/assign"), b, b_init);
// Add a queue.
ops::FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_STRING});
Output some_string =
ops::Const(s.WithOpName("some_string"), string("nothing"));
ops::QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, {some_string});
ops::QueueDequeue dequeue(s.WithOpName("dequeue"), queue,
{DataType::DT_STRING});
// Add a IdentityReader.
ops::IdentityReader reader(s.WithOpName("identity_reader"));
ops::ReaderRead read(s.WithOpName("read_from_queue"), reader, queue);
Output var_mul = ops::MatMul(s.WithOpName("var_matmul"), a, b_read);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
QueueRunnerDef queue_runner;
queue_runner.set_queue_name("queue");
*queue_runner.add_enqueue_op_name() = "enqueue";
item.queue_runners.push_back(queue_runner);
item.init_ops.push_back("a/init/assign");
item.init_ops.push_back("b/init/assign");
item.fetch.push_back("var_matmul");
item.fetch.push_back("dequeue");
// Run the graph
TF_CHECK_OK(cluster_->Initialize(item));
EnableCPUAllocatorStats(true);
SessionOptions options =
GetSessionOption(3 /* cpu cores */, 0 /* num gpus */);
std::unordered_map<string, AllocatorStats> device_memory_before;
TF_CHECK_OK(GetDeviceMemoryStats(options, &device_memory_before));
EXPECT_EQ(device_memory_before.size(), 1);
RunMetadata metadata;
TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
// Check there is memory that is not released.
std::unordered_map<string, AllocatorStats> device_memory;
TF_CHECK_OK(GetDeviceMemoryStats(options, &device_memory));
EXPECT_EQ(device_memory.size(), 1);
EXPECT_GT(device_memory.begin()->second.bytes_in_use, 0);
// Resetting cluster_ should release all memory.
cluster_.reset();
std::unordered_map<string, AllocatorStats> device_memory_after;
TF_CHECK_OK(GetDeviceMemoryStats(options, &device_memory_after));
// Check that memory used by resources is released after cluster destruction.
EXPECT_EQ(device_memory_before.size(), 1);
EXPECT_EQ(device_memory_after.size(), 1);
EXPECT_EQ(device_memory_before.begin()->second.bytes_in_use, 0);
EXPECT_EQ(device_memory_after.begin()->second.bytes_in_use, 0);
}
#endif
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -82,10 +82,6 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Inline all functions.
GraphDef inlined_graph_def(graph_def);
// Populate default attrs to the NodeDefs in the GraphDef, which is required
// by inlining code.
TF_RETURN_IF_ERROR(
AddDefaultAttrsToGraphDef(&inlined_graph_def, *OpRegistry::Global(), 0));
for (int i = 0; i < inlined_graph_def.library().function().size(); i++) {
FunctionDef* fdef =
@ -122,6 +118,10 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
graph_ctor_opts.allow_internal_ops = true;
graph_ctor_opts.expect_device_spec = false;
std::unique_ptr<Graph> graphptr(new Graph(function_library));
// Populate default attrs to the NodeDefs in the GraphDef.
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
*graphptr->op_registry(), 0));
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_ctor_opts, inlined_graph_def,
graphptr.get()));

View File

@ -48,9 +48,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
for (int i = 0; i < num_stages; i++) {
std::vector<Output> this_stage;
for (int j = 0; j < width; j++) {
Output combine = AddN(
s.WithDevice(device_names[use_multiple_devices ? j : 0]), last_stage);
this_stage.push_back(combine);
if (last_stage.size() == 1) {
Output unary_op =
Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
last_stage[0]);
this_stage.push_back(unary_op);
} else {
Output combine =
AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
last_stage);
this_stage.push_back(combine);
}
}
last_stage = this_stage;
}

View File

@ -18,6 +18,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsAddN(const NodeDef& node) {
const auto op = node.op();
return op == "AddN";
}
bool IsConcat(const NodeDef& node) {
const auto op = node.op();
return op == "Concat" || op == "ConcatV2";

View File

@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsAddN(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);

View File

@ -26,6 +26,29 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (int i = 0; i < node.input_size(); ++i) {
if (!node.input(i).empty() && node.input(i)[0] == '^') {
num_inputs--;
}
}
return num_inputs;
}
bool IsTrivialOp(const NodeDef& node) {
// Remove the stop gradient nodes since they serve no purpose once the graph
// is built. Also remove Identity ops.
if (IsStopGradient(node) || IsIdentity(node)) {
return true;
}
if (IsAddN(node) && NumNonControlInputs(node) <= 1) {
return true;
}
return false;
}
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* pruned_graph) {
GraphRewriter rewriter(item);
@ -43,9 +66,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
std::unordered_set<const NodeDef*> nodes_to_delete;
for (auto& node : item.graph.node()) {
// Remove the stop gradient nodes since they serve no purpose once the graph
// is built. Also remove Identity ops.
if (!IsStopGradient(node) && !IsIdentity(node)) {
if (!IsTrivialOp(node)) {
continue;
}
// Don't remove nodes that must be preserved.

View File

@ -57,10 +57,10 @@ TEST_F(ModelPrunerTest, StopGradientPruning) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::StopGradient(s.WithOpName("c"), b);
Output d = ops::StopGradient(s.WithOpName("d"), c);
Output e = ops::AddN(s.WithOpName("e"), {d});
Output e = ops::Sqrt(s.WithOpName("e"), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@ -93,10 +93,10 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::AddN(s.WithOpName("e"), {d});
Output e = ops::Sqrt(s.WithOpName("e"), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@ -126,15 +126,53 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
EXPECT_EQ(NodeName(b.name()), new_c.input(0));
}
TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
TEST_F(ModelPrunerTest, NoOpPruning) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output c = ops::AddN(s.WithOpName("c"), {b});
Output d = ops::AddN(s.WithOpName("d").WithControlDependencies(b), {c});
Output e = ops::AddN(s.WithOpName("e"), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ModelPruner pruner;
GraphDef output;
Status status = pruner.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(5, output.node_size());
const NodeDef& new_a = output.node(0);
EXPECT_EQ(NodeName(a.name()), new_a.name());
const NodeDef& new_b = output.node(1);
EXPECT_EQ(NodeName(b.name()), new_b.name());
const NodeDef& new_c = output.node(2);
EXPECT_EQ(NodeName(c.name()), new_c.name());
const NodeDef& new_d = output.node(3);
EXPECT_EQ(NodeName(d.name()), new_d.name());
const NodeDef& new_e = output.node(4);
EXPECT_EQ(NodeName(e.name()), new_e.name());
EXPECT_EQ(1, new_e.input_size());
EXPECT_EQ(NodeName(d.name()), new_e.input(0));
EXPECT_EQ(2, new_d.input_size());
EXPECT_EQ(NodeName(b.name()), new_d.input(0));
EXPECT_EQ(1, new_c.input_size());
EXPECT_EQ(NodeName(b.name()), new_c.input(0));
}
TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::AddN(s.WithOpName("e").WithControlDependencies(c), {d});
Output e = ops::Sqrt(s.WithOpName("e").WithControlDependencies(c), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@ -166,11 +204,11 @@ TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output c = ops::AddN(s.WithOpName("c"), {a});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Sqrt(s.WithOpName("c"), {a});
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::Identity(s.WithOpName("e"), d);
Output f = ops::AddN(s.WithOpName("f"), {e});
Output f = ops::Sqrt(s.WithOpName("f"), {e});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@ -216,7 +254,7 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
GrapplerItem item;
@ -243,13 +281,13 @@ TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) {
// Node i1 should be preserved.
Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c);
Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});
Output a1 = ops::Sqrt(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
Output a2 = ops::Sqrt(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});
// Node i2 should be pruned since it resides on the sender's device.
Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c);
Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});
Output a3 = ops::Sqrt(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
Output a4 = ops::Sqrt(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));

View File

@ -237,6 +237,8 @@ class SymbolicGradientOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.rendezvous = ctx->rendezvous();
opts.cancellation_manager = ctx->cancellation_manager();
opts.runner = ctx->runner();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());

View File

@ -124,6 +124,20 @@ class MutableHashTableOfScalars final : public LookupInterface {
TensorShape value_shape() const override { return TensorShape(); }
int64 MemoryUsed() const override {
int64 ret = 0;
mutex_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
ret++;
} else {
ret += bucket_size;
}
}
return sizeof(MutableHashTableOfScalars) + ret;
}
private:
// TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
@ -239,6 +253,20 @@ class MutableHashTableOfTensors final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
int64 ret = 0;
mutex_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
ret++;
} else {
ret += bucket_size;
}
}
return sizeof(MutableHashTableOfTensors) + ret;
}
private:
TensorShape value_shape_;
// TODO(andreasst): consider using a read/write lock or a concurrent map
@ -467,6 +495,12 @@ class MutableDenseHashTable final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
mutex_lock l(mu_);
return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
}
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {

View File

@ -39,6 +39,19 @@ static void GetRendezvousKey(const string& key_prefix,
frame_iter.iter_id);
}
static FrameAndIter GetFrameAndIter(OpKernelContext* ctx,
bool hostmem_sendrecv) {
if (hostmem_sendrecv && ctx->call_frame() != nullptr) {
// Host memory send/recv pairs are added by
// common_runtime/memory_types.cc. When the pair of nodes are
// added inside a function, we need to use the function call frame
// to formulate the unique rendezvous key.
return FrameAndIter(reinterpret_cast<uint64>(ctx->call_frame()), 0);
} else {
return ctx->frame_iter();
}
}
SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
string send_device;
OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
@ -56,6 +69,9 @@ SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
// proactively cache the rendezvous key for the top-level.
GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
hostmem_sendrecv_ = false;
}
}
void SendOp::Compute(OpKernelContext* ctx) {
@ -71,7 +87,8 @@ void SendOp::Compute(OpKernelContext* ctx) {
args.device_context = ctx->op_device_context();
args.alloc_attrs = ctx->input_alloc_attr(0);
if (ctx->frame_iter() == FrameAndIter(0, 0)) {
FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
if (frame_iter == FrameAndIter(0, 0)) {
// Use the cached rendezvous key.
VLOG(2) << "Send " << parsed_key_.buf_;
OP_REQUIRES_OK(ctx,
@ -79,7 +96,7 @@ void SendOp::Compute(OpKernelContext* ctx) {
ctx->is_input_dead()));
} else {
Rendezvous::ParsedKey in_loop_parsed;
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
VLOG(2) << "Send " << in_loop_parsed.buf_;
OP_REQUIRES_OK(ctx,
Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));
@ -120,6 +137,9 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
// proactively cache the rendezvous key for the top-level.
GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
hostmem_sendrecv_ = false;
}
}
void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
@ -151,12 +171,13 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
},
std::move(done), _1, _2, _3, _4, _5);
if (ctx->frame_iter() == FrameAndIter(0, 0)) {
FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
if (frame_iter == FrameAndIter(0, 0)) {
VLOG(2) << "Recv " << parsed_key_.buf_;
ctx->rendezvous()->RecvAsync(parsed_key_, args, std::move(done_cb));
} else {
Rendezvous::ParsedKey in_loop_parsed;
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
VLOG(2) << "Recv " << in_loop_parsed.buf_;
OP_REQUIRES_OK_ASYNC(
ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);

View File

@ -29,6 +29,7 @@ class SendOp : public OpKernel {
private:
string key_prefix_;
Rendezvous::ParsedKey parsed_key_;
bool hostmem_sendrecv_;
TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
};
@ -41,6 +42,7 @@ class RecvOp : public AsyncOpKernel {
private:
string key_prefix_;
Rendezvous::ParsedKey parsed_key_;
bool hostmem_sendrecv_;
TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
};

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class DeviceMgr;
/// \brief A Session instance lets a caller drive a TensorFlow graph
/// computation.
@ -177,12 +178,24 @@ class Session {
/// *response. This API is optional. If it is unimplemented, Status will
/// return a corresponding error message, and *response will be unmodified.
virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0;
/// \brief Closes this session.
///
/// Closing a session releases the resources used by this session
/// on the TensorFlow runtime (specified during session creation by
/// the `SessionOptions::target` field).
virtual Status Close() = 0;
// NOTE(ashankar): As of July 2017, this was a method added to
// facilitate some experimentation. Reconsider/re-evaluate after
// September 2017.
//
// Sets `*output` to the `DeviceMgr` that owns accessible devices in the
// address-space of the caller.
virtual Status LocalDeviceManager(const DeviceMgr** output) {
return errors::Unimplemented(
"LocalDeviceManager is not supported for this session.");
}
};
/// \brief Create a new session with the given options.

View File

@ -0,0 +1,153 @@
# Estimators
This document introduces **Estimators**--a high-level TensorFlow API that
greatly simplifies machine learning programming. Estimators encapsulate
the following actions:
* training
* evaluation
* prediction
* export for serving
You may either use the pre-made Estimators we provide or write your
own custom Estimators. All Estimators--whether pre-made or custom--are
classes based on the `tf.estimator.Estimator` class.
Note: TensorFlow also provides an Estimator class at
`tf.contrib.learn.Estimator`, which you should not use.
## Advantages of Estimators
Estimators provide the following benefits:
* You can run Estimators-based models on a local host or on a
distributed multi-server environment without changing your model.
Furthermore, you can run Estimators-based models on CPUs, GPUs,
or TPUs without recoding your model.
* Estimators simplify sharing implementations between model developers.
* You can develop a state-of-the-art model with high-level intuitive code.
In short, it is generally much easier to create models with Estimators
than with the low-level TensorFlow APIs.
* Estimators are themselves built on tf.layers, which
simplifies customization.
* Estimators build the graph for you. In other words, you don't have to
build the graph.
* Estimators provide a safe distributed training loop that controls how and
when to:
* build the graph
* initialize variables
* start queues
* handle exceptions
* create checkpoint files and recover from failures
* save summaries for TensorBoard
When writing an application with Estimators, you must separate the data input
pipeline from the model. This separation simplifies experiments with
different data sets.
## Pre-made Estimators
Pre-made Estimators enable you to work at a much higher conceptual level
than the base TensorFlow APIs. You no longer have to worry about creating
the computational graph or sessions since Estimators handle all
the "plumbing" for you. That is, pre-made Estimators create and manage
`Graph` and `Session` objects for you. Furthermore, pre-made Estimators
let you experiment with different model architectures by making only minimal
code changes. `DNNClassifier`, for example, is a pre-made Estimator class that
trains classification models through dense, feed-forward neural networks.
### Structure of a pre-made Estimators program
A TensorFlow program relying on a pre-made Estimator typically consists
of the following four steps:
1. **Write one or more dataset importing functions.** For example, you might
create one function to import the training set and another function to
import the test set. Each dataset importing function must return two
objects:
* a dictionary in which the keys are feature column names and the
values are Tensors (or SparseTensors) containing the corresponding
feature data
* a Tensor containing one or more labels
For example, the following code illustrates the basic skeleton for
an input function:
def input_fn(dataset):
... # manipulate dataset, extracting feature names and the label
return feature_dict, label
(See @{$datasets$Using the `Dataset` API for TensorFlow Input Pipelines}
for full details.)
2. **Define the feature columns.** Each @{tf.feature_column}
identifies a feature name, its type, and any input pre-processing.
For example, the following snippet creates three feature
columns that hold integer or floating-point data. The first two
feature columns simply identify the feature's name and type. The
third feature column also specifies a lambda the program will invoke
to scale the raw data:
# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
normalizer_fn=lambda x: x - global_education_mean)
3. **Instantiate the relevant pre-made Estimator.** For example, here's
a sample instantiation of a pre-made Estimator named `LinearClassifier`:
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
feature_columns=[population, crime_rate, median_education],
)
4. **Call a training, evaluation, or inference method.**
   For example, all Estimators provide a `train` method, which trains a model.
   (A sketch combining all four steps appears after this list.)
# my_training_set is the function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)
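Putting the four steps together, here is a minimal runnable sketch. The
in-memory data, the feature values, and the `global_education_mean` constant
are illustrative assumptions, not part of the steps above:

    import tensorflow as tf

    # Step 1: a dataset importing function; in-memory constants stand in
    # for a real input pipeline.
    def my_training_set():
      features = {
          'population': tf.constant([40.0, 11.0, 25.0]),
          'crime_rate': tf.constant([0.02, 0.05, 0.03]),
          'median_education': tf.constant([12.5, 16.0, 14.0]),
      }
      labels = tf.constant([0, 1, 0])
      return features, labels

    # Step 2: feature columns (global_education_mean is an assumed constant).
    global_education_mean = 14.0
    population = tf.feature_column.numeric_column('population')
    crime_rate = tf.feature_column.numeric_column('crime_rate')
    median_education = tf.feature_column.numeric_column(
        'median_education',
        normalizer_fn=lambda x: x - global_education_mean)

    # Step 3: instantiate a pre-made Estimator.
    estimator = tf.estimator.LinearClassifier(
        feature_columns=[population, crime_rate, median_education])

    # Step 4: train the model.
    estimator.train(input_fn=my_training_set, steps=2000)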
### Benefits of pre-made Estimators
Pre-made Estimators encode best practices, providing the following benefits:
* Best practices for determining where different parts of the computational
graph should run, implementing strategies on a single machine or on a
cluster.
* Best practices for event (summary) writing and universally useful
summaries.
If you don't use pre-made Estimators, you must implement the preceding
features yourself.
## Custom Estimators
The heart of every Estimator, whether pre-made or custom, is its
**model function**, which is a method that builds graphs for training,
evaluation, and prediction. When you are using a pre-made Estimator,
someone else has already implemented the model function. When relying
on a custom Estimator, you must write the model function yourself. A
@{$extend/estimators$companion document}
explains how to write the model function.
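As a rough illustration only (the single dense layer, squared-error loss, and
optimizer below are assumptions, not the companion document's example), a
model function receives features, labels, and a mode, and returns an
`EstimatorSpec` describing the graph pieces needed for that mode:

    import tensorflow as tf

    def my_model_fn(features, labels, mode):
      # Build the graph; the mode selects which pieces are returned.
      column = tf.feature_column.numeric_column('x')
      net = tf.feature_column.input_layer(features, [column])
      predictions = tf.layers.dense(net, units=1)

      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

      loss = tf.losses.mean_squared_error(
          labels=labels, predictions=tf.squeeze(predictions, axis=1))
      if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
      train_op = optimizer.minimize(
          loss, global_step=tf.train.get_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # A custom Estimator simply wraps the model function:
    estimator = tf.estimator.Estimator(model_fn=my_model_fn)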
## Recommended workflow
We recommend the following workflow:
1. Assuming a suitable pre-made Estimator exists, use it to build your
first model and use its results to establish a baseline.
2. Using this pre-made Estimator, build and test your overall pipeline,
   including the integrity and reliability of your data.
3. If suitable alternative pre-made Estimators are available, run
experiments to determine which pre-made Estimator produces the
best results.
4. If warranted, further improve your model by building your own custom Estimator.

View File

@ -325,6 +325,25 @@ class FunctionTest(test.TestCase):
"assertion"):
_ = MyFn(100.0).eval()
def testWhileLoopCallsFunc(self):
with self.test_session(use_gpu=True) as sess:
@function.Defun(dtypes.float32)
def Times2(x):
constant_two = constant_op.constant(2, dtypes.int32)
two_on_gpu = math_ops.cast(constant_two, dtypes.float32)
return x * two_on_gpu
def Body(x):
x2 = Times2(x)
x2.set_shape([])
return x2
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
ans = sess.run(loop)
self.assertAllClose(ans, 131072.)
def testControlFlowStrictness(self):
"""Inlined functions must not execute in a untaken control flow branch."""
@ -588,8 +607,8 @@ class FunctionTest(test.TestCase):
self.assertAllClose(vals[2], vals[3])
def testDeclare(self):
foo = function.Declare("Foo", [("x", dtypes.float32)],
[("y", dtypes.float32)])
foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
dtypes.float32)])
@function.Defun(dtypes.float32, func_name="Foo", out_names=["y"])
def FooImpl(x):
@ -607,8 +626,8 @@ class FunctionTest(test.TestCase):
self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
def testDeclareUsedInDefun(self):
foo = function.Declare("Foo", [("x", dtypes.float32)],
[("y", dtypes.float32)])
foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
dtypes.float32)])
@function.Defun()
def Bar(x):
@ -630,8 +649,8 @@ class FunctionTest(test.TestCase):
self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
def testDeclareTypeMistake(self):
foo = function.Declare("Foo", [("x", dtypes.float32)],
[("y", dtypes.float32)])
foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
dtypes.float32)])
@function.Defun(dtypes.float32, func_name="Foo", out_names=["y"])
def Foo(x):
@ -749,8 +768,9 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(v1, 20.)
def testShapeFunction(self):
@function.Defun(dtypes.float32,
shape_func=lambda op: [op.inputs[0].get_shape()])
@function.Defun(
dtypes.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
def Foo(x):
return x + 1.0
@ -767,11 +787,12 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
def testVariableReuse(self):
def LinearWithReuse(input_tensor, reuse=None):
size = input_tensor.shape.dims[1]
with variable_scope.variable_scope("linear", reuse=reuse):
w = variable_scope.get_variable("w", shape=[size, size],
dtype=input_tensor.dtype)
w = variable_scope.get_variable(
"w", shape=[size, size], dtype=input_tensor.dtype)
return math_ops.matmul(input_tensor, w)
@function.Defun(dtypes.float32)
@ -789,15 +810,19 @@ class FunctionTest(test.TestCase):
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
output_val = sess.run(output_op,
feed_dict={input_op: np.random.rand(32, 100)})
output_val = sess.run(
output_op, feed_dict={input_op: np.random.rand(32, 100)})
self.assertEqual(output_val.shape, (32, 100))
def testFunctionCallInDifferentVariableScopes(self):
@function.Defun(dtypes.float32)
def Foo(inputs):
var = variable_scope.get_variable("var", shape=[10], dtype=dtypes.float32,
initializer=init_ops.ones_initializer())
var = variable_scope.get_variable(
"var",
shape=[10],
dtype=dtypes.float32,
initializer=init_ops.ones_initializer())
return inputs + var
input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32)
@ -813,8 +838,8 @@ class FunctionTest(test.TestCase):
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
out1, out2 = sess.run([out1_op, out2_op],
feed_dict={input_op: np.linspace(1, 10, 10)})
out1, out2 = sess.run(
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
self.assertAllEqual(out1, np.linspace(2, 11, 10))
self.assertAllEqual(out2, np.linspace(2, 11, 10))
@ -852,12 +877,15 @@ class FunctionsFromProtos(test.TestCase):
self.assertEqual(func.captured_inputs, new_func.captured_inputs)
def testBasic(self):
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(x, y):
return x + y
self.expectFunctionsEqual(Foo)
def testGradFunc(self):
@function.Defun(dtypes.float32, dtypes.float32)
def G(x, dy):
return x * dy
@ -865,10 +893,12 @@ class FunctionsFromProtos(test.TestCase):
@function.Defun(dtypes.float32, grad_func=G)
def F(x):
return math_ops.exp(x) - math_ops.exp(-x)
self.expectFunctionsEqual(F, grad_func=G)
def testCapturedInputs(self):
c = constant_op.constant(10, dtypes.int64)
@function.Defun(dtypes.int64)
def Foo(x):
return x + c
@ -885,6 +915,7 @@ class FunctionsFromProtos(test.TestCase):
self.assertEqual(len(new_func.captured_inputs), 0)
def testNestedFunctions(self):
@function.Defun(dtypes.float32)
def Outer(x):
@ -958,6 +989,7 @@ class FunctionsFromProtos(test.TestCase):
self.assertEqual(len(function._from_library(library)), 0)
def testFromLibraryMissingFuncDef(self):
@function.Defun(dtypes.float32, dtypes.float32)
def G1(x, dy):
return x * dy
@ -989,6 +1021,7 @@ class FunctionsFromProtos(test.TestCase):
function._from_library(library)
def testFromLibraryCyclicGradFuncs(self):
@function.Defun(dtypes.float32)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
@ -1242,10 +1275,11 @@ class FunctionInlineControlTest(test.TestCase):
inp = np.random.uniform(-1, 1, [16, 1]).astype(np.float32)
run_metadata = config_pb2.RunMetadata()
with session.Session(graph=g, config=cfg) as sess:
ans = sess.run([y, dx], {x: inp},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
ans = sess.run(
[y, dx], {x: inp},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
print(ans[0], np.sum(ans[1]))
self.assertAllClose(ans[0], 255.971, rtol=1e-3)
self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)
@ -1275,8 +1309,7 @@ class ModuleFunctionTest(test.TestCase):
def testBasic(self):
with ops.Graph().as_default():
a, b, c, d, e = [
constant_op.constant(
[[_]], dtype=dtypes.float32) for _ in range(5)
constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
]
y = Linear(a, b, c)
z = Linear2(a, b, c, d, e)
@ -1295,7 +1328,8 @@ class VariableHoistingTest(test.TestCase):
initializer=init_ops.random_uniform_initializer(seed=312),
use_resource=use_resource)
b = variable_scope.get_variable(
"b", (64), initializer=init_ops.zeros_initializer(),
"b", (64),
initializer=init_ops.zeros_initializer(),
use_resource=use_resource)
return math_ops.sigmoid(math_ops.matmul(x, w) + b)
@ -1354,5 +1388,6 @@ class VariableHoistingTest(test.TestCase):
self._testSimpleModel(True, use_resource=True)
self._testSimpleModel(False, use_resource=True)
if __name__ == "__main__":
test.main()

View File

@ -148,7 +148,7 @@ def input_producer(input_tensor,
"""
with ops.name_scope(name, "input_producer", [input_tensor]):
input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
element_shape = input_tensor.shape[1:].merge_with(element_shape)
if not element_shape.is_fully_defined():
raise ValueError("Either `input_tensor` must have a fully defined shape "
"or `element_shape` must be specified")
@ -168,7 +168,7 @@ def input_producer(input_tensor,
q, [enq], cancel_op=cancel_op))
if summary_name is not None:
summary.scalar(summary_name,
math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
math_ops.to_float(q.size()) * (1. / capacity))
return q
@ -465,7 +465,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
def _sparse_meta_data(t, storing_op, map_op):
if not isinstance(t, sparse_tensor.SparseTensor):
return _SparseMetaData(False, None, None)
rank = t.dense_shape.get_shape().with_rank(1)[0]
rank = t.dense_shape.shape.with_rank(1)[0]
if enqueue_many:
rank -= 1
# If a shared map_op was provided, use that. Otherwise use the name of
@ -492,7 +492,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
out_tensor.set_shape([None]) # necessary when t.ndims is unknown
return out_tensor
if keep_input.get_shape().ndims == 1:
if keep_input.shape.ndims == 1:
t = sparse_ops.sparse_retain(t, keep_input)
store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name)
elif enqueue_many:
@ -577,13 +577,13 @@ def _validate_join(tensor_list_list):
def _validate_keep_input(keep_input, enqueue_many):
"""Validate `keep_input` argument to conditional batching functions."""
keep_input = ops.convert_to_tensor(keep_input)
if keep_input.get_shape().ndims is None:
if keep_input.shape.ndims is None:
raise ValueError(
"`keep_input` dimensions must be known at graph construction.")
if not enqueue_many and keep_input.get_shape().ndims == 1:
if not enqueue_many and keep_input.shape.ndims == 1:
raise ValueError(
"`keep_input` cannot be a vector when `enqueue_many=False`.")
if keep_input.get_shape().ndims > 1:
if keep_input.shape.ndims > 1:
raise ValueError("`keep_input` must be 0 or 1 dimensions.")
return keep_input
@ -632,18 +632,18 @@ def _shapes(tensor_list_list, shapes, enqueue_many):
for tl in tensor_list_list:
for i in xrange(len0):
if tl[i].get_shape().ndims is None:
if tl[i].shape.ndims is None:
raise ValueError("Cannot infer Tensor's rank: %s" % tl[i])
shapes = [_merge_shapes(
[tl[i].get_shape().as_list() for tl in tensor_list_list], enqueue_many)
[tl[i].shape.as_list() for tl in tensor_list_list], enqueue_many)
for i in xrange(len0)]
return shapes
def _select_which_to_enqueue(tensor_list, keep_input):
"""Select which examples to enqueue based on vector `keep_input`."""
select_i = math_ops.cast(keep_input, dtypes.int32)
select_i = math_ops.to_int32(keep_input)
tensor_list = [
data_flow_ops.dynamic_partition(x, select_i, num_partitions=2)[1]
for x in tensor_list]
@ -656,7 +656,7 @@ def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input):
enqueue_fn = queue.enqueue_many
else:
enqueue_fn = queue.enqueue
if keep_input.get_shape().ndims == 1:
if keep_input.shape.ndims == 1:
enqueue_ops = [enqueue_fn(_select_which_to_enqueue(x, keep_input))
for x in tensor_list_list]
else:
@ -673,7 +673,7 @@ def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input):
enqueue_fn = queue.enqueue_many
else:
enqueue_fn = queue.enqueue
if keep_input.get_shape().ndims == 1:
if keep_input.shape.ndims == 1:
enqueue_ops = [
enqueue_fn(_select_which_to_enqueue(tensor_list, keep_input))] * threads
else:
@ -707,8 +707,7 @@ def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32,
capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
_enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input)
summary.scalar("fraction_of_%d_full" % capacity,
math_ops.cast(queue.size(), dtypes.float32) *
(1. / capacity))
math_ops.to_float(queue.size()) * (1. / capacity))
if allow_smaller_final_batch:
dequeued = queue.dequeue_up_to(batch_size, name=name)
@ -742,8 +741,7 @@ def _batch_join(tensors_list, batch_size, keep_input, capacity=32,
capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
_enqueue_join(queue, tensor_list_list, enqueue_many, keep_input)
summary.scalar("fraction_of_%d_full" % capacity,
math_ops.cast(queue.size(), dtypes.float32) *
(1. / capacity))
math_ops.to_float(queue.size()) * (1. / capacity))
if allow_smaller_final_batch:
dequeued = queue.dequeue_up_to(batch_size, name=name)
@ -775,8 +773,8 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
dtypes=types, shapes=shapes, shared_name=shared_name)
_enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input)
full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
dtypes.float32) *
full = (math_ops.to_float(
math_ops.maximum(0, queue.size() - min_after_dequeue)) *
(1. / (capacity - min_after_dequeue)))
# Note that name contains a '/' at the end so we intentionally do not place
# a '/' after %s below.
@ -812,8 +810,8 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
dtypes=types, shapes=shapes, shared_name=shared_name)
_enqueue_join(queue, tensor_list_list, enqueue_many, keep_input)
full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
dtypes.float32) *
full = (math_ops.to_float(
math_ops.maximum(0, queue.size() - min_after_dequeue)) *
(1. / (capacity - min_after_dequeue)))
# Note that name contains a '/' at the end so we intentionally do not place
# a '/' after %s below.