Merge commit for internal changes

Vijay Vasudevan 2017-05-03 18:44:10 -07:00
commit bbce813a58
94 changed files with 4186 additions and 1184 deletions

View File

@@ -388,6 +388,16 @@ tf_gen_op_wrappers_cc(
     visibility = ["//tensorflow:internal"],
 )
 
+tf_gen_op_wrappers_cc(
+    name = "functional_ops",
+    include_internal_ops = 1,
+    op_lib_names = [
+        "functional_ops",
+    ],
+    pkg = "//tensorflow/core",
+    visibility = ["//tensorflow:internal"],
+)
+
 tf_gen_op_wrappers_cc(
     name = "resource_variable_ops",
     include_internal_ops = 1,

View File

@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -339,6 +340,14 @@ class LiteralUtil {
                                       const Layout& layout,
                                       Literal* literal);
 
+  // Populates literal values by calling the generator function for every cell
+  // in the literal object.
+  template <typename NativeT>
+  static Status Populate(
+      Literal* literal,
+      const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
+          generator);
+
   // Creates a Literal of the given dimensions with all elements set to the
   // given value.
   template <typename NativeT>
@@ -992,6 +1001,43 @@ template <typename NativeT>
                                               literal);
 }
 
+template <typename NativeT>
+/* static */ Status LiteralUtil::Populate(
+    Literal* literal,
+    const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
+        generator) {
+  const Shape& shape = literal->shape();
+  int64 rank = ShapeUtil::Rank(shape);
+  TF_RET_CHECK(shape.element_type() ==
+               primitive_util::NativeToPrimitiveType<NativeT>());
+  tensorflow::protobuf::RepeatedField<NativeT>* data =
+      GetMutableRepeatedField<NativeT>(literal);
+  if (rank > 0) {
+    std::vector<int64> base(rank, 0);
+    std::vector<int64> step(rank, 1);
+    std::vector<int64> minor_scan_indexes(rank, 0);
+    int64 minor_dimension = shape.layout().minor_to_major()[0];
+    int64 minor_dimension_size =
+        ShapeUtil::GetDimension(shape, minor_dimension);
+    step[minor_dimension] = minor_dimension_size;
+    auto init_function = [&](const std::vector<int64>& indexes) {
+      int64 index = LinearIndex(*literal, indexes);
+      std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
+      for (int64 i = 0; i < minor_dimension_size; ++i) {
+        minor_scan_indexes[minor_dimension] = i;
+        data->Set(index + i, generator(minor_scan_indexes));
+      }
+      return true;
+    };
+    ShapeUtil::ForEachIndex(shape, base, AsInt64Slice(shape.dimensions()), step,
+                            init_function);
+  } else {
+    data->Set(0, generator({}));
+  }
+  return Status::OK();
+}
+
 template <typename NativeT>
 /* static */ void LiteralUtil::PopulateWithValue(
     NativeT value, tensorflow::gtl::ArraySlice<int64> dimensions,
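For orientation (not part of the diff), a minimal usage sketch of the new Populate overload; the shape, element type, and values below are made up for illustration and assume code already inside the xla namespace:

// Hypothetical sketch: fill a 2x3 F32 literal so element (i, j) holds i*10+j.
std::unique_ptr<Literal> literal =
    LiteralUtil::CreateFromDimensions(F32, {2, 3});
TF_CHECK_OK(LiteralUtil::Populate<float>(
    literal.get(), [](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
      return indexes[0] * 10 + indexes[1];
    }));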

View File

@@ -422,7 +422,7 @@ class ReferenceUtil {
   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
     AssertSameSize2D(array1, arrays...);
-    auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n1());
+    auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n2());
     for (int64 i = 0; i < array1.n1(); ++i) {
       for (int64 j = 0; j < array1.n2(); ++j) {
         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);

View File

@@ -80,8 +80,6 @@ cc_library(
         ":hlo_query",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@@ -666,8 +664,8 @@ cc_library(
     ],
     deps = [
         ":buffer_liveness",
-        ":heap_simulator",
         ":hlo",
+        ":hlo_ordering",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -707,51 +705,38 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "heap_simulator",
-    srcs = [
-        "heap_simulator.cc",
-    ],
-    hdrs = [
-        "heap_simulator.h",
-    ],
-    deps = [
-        ":hlo",
-        ":liveness_util",
-        ":logical_buffer",
-        ":tuple_points_to_analysis",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/core:lib",
-    ],
-)
-
 cc_test(
     name = "heap_simulator_test",
     srcs = ["heap_simulator_test.cc"],
     deps = [
-        ":heap_simulator",
         ":hlo",
+        ":hlo_ordering",
         ":logical_buffer",
         ":tuple_points_to_analysis",
+        "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:lib",
        "//tensorflow/core:test_main",
     ],
 )
 
+# The hlo_ordering library contains both hlo_ordering and heap_simulator because
+# they are mutually dependent.
 cc_library(
     name = "hlo_ordering",
     srcs = [
+        "heap_simulator.cc",
         "hlo_ordering.cc",
     ],
     hdrs = [
+        "heap_simulator.h",
         "hlo_ordering.h",
     ],
     deps = [
         ":call_graph",
-        ":heap_simulator",
         ":hlo",
+        ":liveness_util",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -1436,6 +1421,7 @@ cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:test_main",
     ],

View File

@@ -548,6 +548,8 @@ Status BufferAssigner::AssignBuffersForComputation(
     const FlatSet<const HloInstruction*>* hlos_to_allocate,
     const FlatSet<const LogicalBuffer*>& colocated_buffers,
     const FlatSet<BufferAllocation::Index>& colocated_allocations,
+    FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
+        buffers_to_assign_sequentially,
     BufferAssignment* assignment) {
   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
   // size.
@@ -578,9 +580,16 @@ Status BufferAssigner::AssignBuffersForComputation(
   // If there is a sequential instruction ordering, we'll delay assignment of
   // temp buffers until after the main assignment loop.
   const BufferLiveness& liveness = assignment->liveness();
-  const std::vector<const HloInstruction*>* sequential_order =
-      liveness.hlo_ordering().SequentialOrder(*computation);
-  FlatSet<const LogicalBuffer*> unassigned_temp_buffers;
+  const bool has_sequential_order =
+      liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
+  if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
+    // Every sequential computation must get an entry in the
+    // buffers_to_assign_sequentially map, even if we end up with an empty set
+    // of buffers. This ensures we can correctly determine whether to run
+    // whole-module heap simulation.
+    buffers_to_assign_sequentially->emplace(computation,
+                                            FlatSet<const LogicalBuffer*>());
+  }
 
   // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
   // first for simplicity. This means any previously created BufferAllocation is
@@ -599,7 +608,7 @@ Status BufferAssigner::AssignBuffersForComputation(
   // important reuse case where an elementwise instruction reuses one of its
   // operand's buffer. This improves locality.
   std::sort(sorted_buffers.begin(), sorted_buffers.end(),
-            [this, sequential_order, &liveness, &post_order_position](
+            [this, has_sequential_order, &liveness, &post_order_position](
                 const LogicalBuffer* a, const LogicalBuffer* b) {
               // Primary sort is by decreasing buffer size.
               const int64 a_size = buffer_size_(*a);
@@ -609,7 +618,7 @@ Status BufferAssigner::AssignBuffersForComputation(
               }
               // Otherwise live out buffers come before others, if the
               // instructions are sequentially ordered.
-              if (sequential_order != nullptr) {
+              if (has_sequential_order) {
                 const bool a_live_out = liveness.MaybeLiveOut(*a);
                 const bool b_live_out = liveness.MaybeLiveOut(*b);
                 if (a_live_out != b_live_out) {
@@ -746,7 +755,7 @@ Status BufferAssigner::AssignBuffersForComputation(
       }
     }
 
-    if (!assignment->HasAllocation(*buffer) && sequential_order != nullptr &&
+    if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
         !liveness.MaybeLiveOut(*buffer)) {
       // There is a sequential instruction ordering, so we delay assignment of
      // temp buffers until after the loop. We do this right before we decide to
@@ -758,7 +767,7 @@ Status BufferAssigner::AssignBuffersForComputation(
       // for the definition of temp buffers.
       CHECK(!is_entry_parameter) << *buffer;
       CHECK(!is_thread_local) << *buffer;
-      unassigned_temp_buffers.insert(buffer);
+      (*buffers_to_assign_sequentially)[computation].insert(buffer);
       VLOG(3) << "Delaying assignment of temp buffer: " << *buffer;
       continue;
     }
@@ -772,27 +781,68 @@ Status BufferAssigner::AssignBuffersForComputation(
     }
   }
 
-  if (!unassigned_temp_buffers.empty()) {
-    TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
-        *sequential_order, unassigned_temp_buffers, *computation, assignment));
-  }
   return Status::OK();
 }
 
 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
-    const std::vector<const HloInstruction*>& sequence,
-    const FlatSet<const LogicalBuffer*>& buffers_to_assign,
-    const HloComputation& computation, BufferAssignment* assignment) {
+    const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>&
+        buffers_to_assign_sequentially,
+    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
   // Run the sequence of instructions through the heap simulator. The heuristic
   // that seems to give the best results is lazy-best-fit, with all runs of
   // alloc / free calls sorted in decreasing size order.
-  TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
-                             MakeUnique<LazyBestFitHeap>(alignment_)),
-                         sequence, computation,
-                         assignment->points_to_analysis(), buffer_size_,
-                         &buffers_to_assign));
+  const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+  if (run_whole_module_heap_simulation) {
+    // Run the heap simulation over the whole module. This reduces memory usage,
+    // since buffers for kCall and kWhile sub-computations are only live for the
+    // duration of their calling instructions.
+    VLOG(1) << "Running whole-module heap simulation";
+    SequentialHloOrdering::HloModuleSequence module_sequence;
+    FlatSet<const LogicalBuffer*> all_buffers_to_assign;
+    for (const auto& pair : buffers_to_assign_sequentially) {
+      const HloComputation* computation = pair.first;
+      const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+      const std::vector<const HloInstruction*>* instruction_sequence =
+          hlo_ordering.SequentialOrder(*computation);
+      CHECK(instruction_sequence != nullptr) << computation->name();
+      module_sequence[computation] = *instruction_sequence;
+      all_buffers_to_assign.insert(buffers_to_assign.begin(),
+                                   buffers_to_assign.end());
+    }
+    TF_ASSIGN_OR_RETURN(
+        const HeapSimulator::Result result,
+        HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
+                               MakeUnique<LazyBestFitHeap>(alignment_)),
+                           assignment->module(), module_sequence,
+                           assignment->points_to_analysis(), buffer_size_,
+                           &all_buffers_to_assign));
+    AssignBuffersFromHeapSimulator(result, assignment);
+  } else {
+    // Run the heap-simulation on a per-computation basis. Buffers for
+    // sub-computations are assigned disjoint BufferAllocations, assuming the
+    // worst-case that they may all be live concurrently.
+    VLOG(1) << "Running per-computation heap simulation";
+    for (const auto& pair : buffers_to_assign_sequentially) {
+      const HloComputation* computation = pair.first;
+      const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+      const std::vector<const HloInstruction*>* instruction_sequence =
+          hlo_ordering.SequentialOrder(*computation);
+      CHECK(instruction_sequence != nullptr) << computation->name();
+      TF_ASSIGN_OR_RETURN(
+          const HeapSimulator::Result result,
+          HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
+                                 MakeUnique<LazyBestFitHeap>(alignment_)),
+                             *computation, *instruction_sequence,
+                             assignment->points_to_analysis(), buffer_size_,
+                             &buffers_to_assign));
+      AssignBuffersFromHeapSimulator(result, assignment);
+    }
+  }
+  return Status::OK();
+}
+
+void BufferAssigner::AssignBuffersFromHeapSimulator(
+    const HeapSimulator::Result& result, BufferAssignment* assignment) {
   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
     assignment->stats_.preallocated_temp_fragmentation_bytes =
         result.fragmentation_size;
@@ -801,8 +851,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
         result.fragmentation_size;
   }
 
-  // Use the results of the heap simulator to create one allocation per
-  // computation, with LogicalBuffers packed to specific offsets.
   BufferAllocation* allocation = assignment->NewEmptyAllocation(
       result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true);
   for (const auto& buffer_chunk : result.chunk_map) {
@@ -810,7 +858,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
     const HeapSimulator::Chunk& chunk = buffer_chunk.second;
     assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
   }
-  return Status::OK();
 }
 
 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
@@ -1108,8 +1155,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
                       BufferLiveness::Run(module, std::move(hlo_ordering)));
 
-  std::vector<const HloComputation*> thread_local_computations;
-  std::vector<const HloComputation*> global_computations;
   VLOG(1) << "Assigning buffers to module " << module->name();
   if (hlos_to_allocate != nullptr) {
     VLOG(3) << "LogicalBuffer assignment restricted to hlos: ";
@@ -1121,9 +1166,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   XLA_VLOG_LINES(3, liveness->ToString());
   XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
 
-  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
-      module, &thread_local_computations, &global_computations));
-
   // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to
   // AssignBuffersForComputation for fast membership testing.
   std::unique_ptr<FlatSet<const HloInstruction*>> hlo_set;
@@ -1148,16 +1190,38 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
                             &colocated_buffers, &colocated_allocations);
 
+  std::vector<const HloComputation*> thread_local_computations;
+  std::vector<const HloComputation*> global_computations;
+  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
+      module, &thread_local_computations, &global_computations));
+
+  // First assign buffers for global computations. Temporary buffers for
+  // sequential computations are collected in 'buffers_to_assign_sequentially'.
+  FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
+      buffers_to_assign_sequentially;
   for (auto* computation : global_computations) {
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
         computation, /*is_thread_local=*/false, hlo_set.get(),
-        colocated_buffers, colocated_allocations, assignment.get()));
+        colocated_buffers, colocated_allocations,
+        &buffers_to_assign_sequentially, assignment.get()));
   }
+
+  // Assign buffers with sequential ordering, if any. If all global computations
+  // are sequential, we can run heap simulation on the whole module, which
+  // reduces memory usage.
+  const bool run_whole_module_heap_simulation =
+      buffers_to_assign_sequentially.size() == global_computations.size();
+  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
+      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
+      assignment.get()));
+
+  // Now assign buffers for thread-local computations. All LogicalBuffers get
+  // their own BufferAllocation.
   for (auto* computation : thread_local_computations) {
     TF_RET_CHECK(computation != module->entry_computation());
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
         computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers,
-        colocated_allocations, assignment.get()));
+        colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr,
+        assignment.get()));
   }
 
   // Mark all buffers which may be live out of the entry computation as

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -354,6 +355,9 @@ class BufferAssignment {
   void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer,
                      int64 offset, int64 size);
 
+  // Returns the HloModule used to construct this assignment.
+  const HloModule& module() { return *module_; }
+
   // Returns the BufferLiveness object used to construct this assignment.
   const BufferLiveness& liveness() { return *liveness_; }
 
@@ -427,14 +431,27 @@ class BufferAssigner {
       const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
       const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
           colocated_allocations,
+      tensorflow::gtl::FlatMap<const HloComputation*,
+                               tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
+          buffers_to_assign_sequentially,
       BufferAssignment* assignment);
 
-  // Assigns 'buffers_to_assign' assuming the HLO instructions will be executed
-  // in the given 'sequential_order'.
+  // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
+  // the HLO instructions will be executed in the sequential order given by
+  // assignment->liveness().hlo_ordering().SequentialOrder. If
+  // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
+  // assuming all global computations are sequentially ordered.
   Status AssignBuffersWithSequentialOrdering(
-      const std::vector<const HloInstruction*>& sequential_order,
-      const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers_to_assign,
-      const HloComputation& computation, BufferAssignment* assignment);
+      const tensorflow::gtl::FlatMap<
+          const HloComputation*,
+          tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
+          buffers_to_assign_sequentially,
+      bool run_whole_module_heap_simulation, BufferAssignment* assignment);
+
+  // Uses the results of the heap simulator to create a single allocation, with
+  // LogicalBuffers packed to specific offsets.
+  void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
+                                      BufferAssignment* assignment);
 
   // Tries to assign the given instruction to the given buffer. Returns if the
   // assignment was successful.
@@ -477,8 +494,6 @@ class BufferAssigner {
       const HloComputation& computation, const BufferLiveness& buffer_liveness,
       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
 
-  const HloModule* module_;
-
   // Function which returns the buffer size for a given logical buffer (shape).
   LogicalBuffer::SizeFunction buffer_size_;

View File

@@ -24,6 +24,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                       int64 operand_index) {
   HloInstruction* producer = consumer->mutable_operand(operand_index);
 
+  // Output fusion is not currently supported on CPUs.
+  if (producer->opcode() == HloOpcode::kFusion) {
+    return false;
+  }
+
   // Condition for consumer: must be elementwise or a fusion op
   // (which necessarily only contains elementwise operations)
   if (!(consumer->opcode() == HloOpcode::kFusion ||

View File

@@ -46,6 +46,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                       int64 operand_index) {
   HloInstruction* producer = consumer->mutable_operand(operand_index);
 
+  // Output fusion is not currently supported on GPUs.
+  if (producer->opcode() == HloOpcode::kFusion) {
+    return false;
+  }
+
   // RNG operations are not currently parallel-friendly on GPU.
   if (producer->opcode() == HloOpcode::kRng) {
     return false;

View File

@@ -53,12 +53,44 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
 
 /*static*/
 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm,
-    const std::vector<const HloInstruction*>& instruction_sequence,
-    const HloComputation& computation,
+    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+    const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_fn,
     const FlatSet<const LogicalBuffer*>* buffers_to_assign) {
+  HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
+  const HloComputation* entry_computation = module.entry_computation();
+  const std::vector<const HloInstruction*>& instruction_sequence =
+      FindOrDie(module_sequence, entry_computation);
+  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
+                                         instruction_sequence,
+                                         points_to_analysis, &module_sequence));
+  return heap.Finish();
+}
+
+/*static*/
+StatusOr<HeapSimulator::Result> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
+    const std::vector<const HloInstruction*>& instruction_sequence,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_fn,
+    const FlatSet<const LogicalBuffer*>* buffers_to_assign) {
+  HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
+  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
+                                         points_to_analysis,
+                                         /*module_sequence=*/nullptr));
+  return heap.Finish();
+}
+
+// Runs a heap simulation for the given 'computation', assuming the given
+// 'instruction_sequence'. If 'module_sequence' is non-null, it is used to find
+// kCall and kWhile sub-computations, and the heap simulation for those
+// sub-computations will be run recursively.
+Status HeapSimulator::RunComputation(
+    const HloComputation& computation,
+    const std::vector<const HloInstruction*>& instruction_sequence,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const SequentialHloOrdering::HloModuleSequence* module_sequence) {
   // The goal here is to minimize memory usage, assuming the given sequential
   // ordering of instructions. The strategy is to walk through the instruction
   // sequence, calling Alloc and Free on the underlying heap algorithm. The
@@ -67,7 +99,6 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
   // 'live_buffers' tracks the liveness of each buffer that we assign, by
   // associating it with a set of HloInstructions that need to be visited. When
   // the set becomes empty, the buffer is no longer used, and can be freed.
-  HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
   FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;
 
   const HloInstruction* root = computation.root_instruction();
@@ -90,7 +121,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     // lifetime of buffers that aren't already connected by a data dependency.
     std::vector<const LogicalBuffer*> dead_buffers_to_free;
     for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
-      if (heap.IgnoreBuffer(buffer)) {
+      if (IgnoreBuffer(buffer)) {
        continue;
      }
      for (const BufferAlias& alias :
@@ -127,7 +158,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::vector<const LogicalBuffer*> operand_buffers_to_free;
     for (const LogicalBuffer* operand_buffer :
          UniqueOperandSourceBuffers(instruction, points_to_analysis)) {
-      if (heap.IgnoreBuffer(operand_buffer)) {
+      if (IgnoreBuffer(operand_buffer)) {
        continue;
      }
      live_buffers[operand_buffer].erase(instruction);
@@ -142,10 +173,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     // happen before dead or operand buffers are freed; the instruction reads
     // the operand buffers to produce its output.
     //
-    // INVARIANT: Either heap.Alloc or heap.ShareBuffer will be called for each
-    // buffer that we should assign.
+    // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
+    // that we should assign.
     for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
-      if (heap.IgnoreBuffer(buffer)) {
+      if (IgnoreBuffer(buffer)) {
        continue;
      }
@@ -159,24 +190,50 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
             CanShareOperandBufferWithUser(
                 operand_buffer->instruction(), operand_buffer->index(),
                 buffer->instruction(), buffer->index(), points_to_analysis)) {
-          heap.ShareBuffer(buffer, operand_buffer);
+          ShareBuffer(buffer, operand_buffer);
          shared = true;
          break;
        }
      }
 
      if (!shared) {
-        heap.Alloc(buffer);
+        Alloc(buffer);
      }
    }
 
+    // If the whole module is sequential, we can save memory by running the
+    // heap-simulation for sub-computations inline. E.g. the buffers for the
+    // condition and body of a kWhile instruction are only live for the duration
+    // of the instruction itself.
+    //
+    // The order that the sub-computations are simulated does not affect
+    // correctness; since the whole module is sequential, we know that the
+    // sub-computations will never be run concurrently.
+    if (module_sequence != nullptr) {
+      if (instruction->opcode() == HloOpcode::kCall ||
+          instruction->opcode() == HloOpcode::kWhile) {
+        for (const HloComputation* called_computation :
+             instruction->called_computations()) {
+          const std::vector<const HloInstruction*>& called_sequence =
+              FindOrDie(*module_sequence, called_computation);
+          TF_RETURN_IF_ERROR(RunComputation(*called_computation,
+                                            called_sequence, points_to_analysis,
+                                            module_sequence));
+        }
+      }
+      // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are
+      // assigned "thread-local" allocations, meaning their buffers are not
+      // allocated up-front at the beginning of the computation.
+    }
+
    // Free buffers that are no longer live. This is the earliest point that we
    // can de-allocate; right after the last use of the buffer.
    for (const LogicalBuffer* buffer : dead_buffers_to_free) {
-      heap.Free(buffer);
+      Free(buffer);
    }
    for (const LogicalBuffer* buffer : operand_buffers_to_free) {
-      heap.Free(buffer);
+      Free(buffer);
    }
  }
@@ -187,10 +244,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
     CHECK_EQ(pending.size(), 1) << *buffer;
     CHECK(*pending.begin() == nullptr) << *buffer;
-    heap.Free(buffer);
+    Free(buffer);
   }
 
-  return heap.Finish();
+  return Status::OK();
 }
 
 HeapSimulator::HeapSimulator(
@@ -309,6 +366,11 @@ HeapSimulator::Result HeapSimulator::Finish() {
       result.chunk_map.emplace(buffer, chunk);
     }
   }
+
+  // If we were told to assign specific buffers, make sure we've assigned
+  // exactly that many buffers.
+  if (buffers_to_assign_ != nullptr) {
+    CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size());
+  }
 }
 
 // Fragmentation is the difference between the actual and ideal sizes.

View File

@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -63,17 +64,32 @@ class HeapSimulator {
   };
 
   // Run the heap simulation with the given algorithm, assuming the given
-  // sequential ordering of instructions. The 'instruction_sequence' must
-  // contain a topologically-consistent total ordering of all instructions in
-  // the computation. The result is invalid if instructions are not run in
-  // exactly this sequence.
+  // module_sequence, which must contain a topologically-consistent total
+  // ordering of all instructions within each computation. The result is invalid
+  // if instructions are not run in exactly this sequence.
+  //
+  // Running heap simulation on the whole module tends to save memory, compared
+  // to running on a per-computation basis, since we can re-use buffer space for
+  // called sub-computations.
   //
   // If 'buffers_to_assign' is provided, only those buffers are assigned
   // offsets, otherwise all buffers defined by the instructions are assigned.
+  static StatusOr<Result> Run(
+      std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+      const SequentialHloOrdering::HloModuleSequence& module_sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const LogicalBuffer::SizeFunction& size_fn,
+      const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign =
+          nullptr);
+
+  // Same as above, but runs on a single computation. The 'instruction_sequence'
+  // must contain a topologically-consistent total ordering of all instructions
+  // in the computation. The result is invalid if instructions are not run in
+  // exactly this sequence.
   static StatusOr<Result> Run(
       std::unique_ptr<HeapAlgorithm> algorithm,
-      const std::vector<const HloInstruction*>& instruction_sequence,
       const HloComputation& computation,
+      const std::vector<const HloInstruction*>& instruction_sequence,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_fn,
       const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign =
@@ -86,6 +102,12 @@ class HeapSimulator {
       const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign);
   ~HeapSimulator();
 
+  Status RunComputation(
+      const HloComputation& computation,
+      const std::vector<const HloInstruction*>& instruction_sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const SequentialHloOrdering::HloModuleSequence* module_sequence);
+
   bool IgnoreBuffer(const LogicalBuffer* buffer) const;
   void Alloc(const LogicalBuffer* buffer);
   void Free(const LogicalBuffer* buffer);
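For orientation (not part of the diff), a small sketch of driving the new whole-module overload; it mirrors the MinimumMemoryForSequence rewrite in this commit, the helper name is made up, and the caller is assumed to already have a module sequence, points-to analysis, and size function:

// Hypothetical helper, assuming the xla namespace: returns the simulated heap
// size for the whole module under the given sequential ordering.
StatusOr<int64> SimulatedHeapSize(
    const HloModule& module,
    const SequentialHloOrdering::HloModuleSequence& module_sequence,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_fn) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), module,
                         module_sequence, points_to_analysis, size_fn));
  return result.heap_size;
}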

View File

@@ -19,13 +19,16 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace xla {
 namespace {
@@ -69,6 +72,7 @@ class HeapCallRecorder : public HeapAlgorithm {
 // sequence against an expected sequence.
 class HeapSimulatorTracker {
  public:
+  // Constructor for testing a single entry computation.
   HeapSimulatorTracker(
       const string& name, std::unique_ptr<HloComputation> computation,
       const std::vector<const HloInstruction*>& instruction_sequence) {
@@ -83,12 +87,48 @@ class HeapSimulatorTracker {
     auto zero_size = [](const LogicalBuffer& buffer) { return 0; };
     auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
         MakeUnique<HeapCallRecorder>(&actual_calls_));
-    result_ = HeapSimulator::Run(std::move(algorithm), instruction_sequence,
-                                 *module_->entry_computation(),
-                                 *points_to_analysis_, zero_size)
+    result_ = HeapSimulator::Run(
+                  std::move(algorithm), *module_->entry_computation(),
+                  instruction_sequence, *points_to_analysis_, zero_size)
                   .ConsumeValueOrDie();
   }
 
+  explicit HeapSimulatorTracker(const string& name) {
+    module_ = MakeUnique<HloModule>(name);
+  }
+
+  // Similar to the single entry computation constructor above, but runs the
+  // simulation over the entire module.
+  void RunWholeModule(
+      const std::vector<const HloInstruction*>& full_module_sequence) {
+    points_to_analysis_ =
+        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+
+    // Construct the module sequence grouped by computation.
+    SequentialHloOrdering::HloModuleSequence module_sequence;
+    tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
+    for (int i = 0; i < full_module_sequence.size(); ++i) {
+      const HloInstruction* instruction = full_module_sequence[i];
+      module_sequence[instruction->parent()].push_back(instruction);
+      reverse_position[instruction] = full_module_sequence.size() - i;
+    }
+
+    // Hack the size_fn so that it returns a decreasing value as we step through
+    // the sequence. This lets us ensure the Alloc calls are in the sequence
+    // order. The Free calls are sorted by LogicalBuffer.id, which is at least
+    // deterministic.
+    auto size_fn = [&reverse_position](const LogicalBuffer& buffer) {
+      return reverse_position[buffer.instruction()];
+    };
+    auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
+        MakeUnique<HeapCallRecorder>(&actual_calls_));
+    result_ = HeapSimulator::Run(std::move(algorithm), *module_,
+                                 module_sequence, *points_to_analysis_, size_fn)
+                  .ConsumeValueOrDie();
+  }
+
+  HloModule* module() { return module_.get(); }
+
   // Returns the buffer defined at the given instruction and index.
   const LogicalBuffer* BufferAt(const HloInstruction* instruction,
                                 const ShapeIndex& index) const {
@@ -358,6 +398,86 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
   });
 }
 
+TEST_F(HeapSimulatorTest, WholeModule) {
+  HeapSimulatorTracker tracker(TestName());
+
+  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+  HloInstruction* cond_data = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+                                   HloOpcode::kLt, cond_iter, cond_data));
+  HloComputation* cond_computation =
+      tracker.module()->AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloComputation* body_computation =
+      tracker.module()->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, cond_computation, body_computation, param));
+  tracker.module()->AddEntryComputation(builder.Build());
+
+  tracker.RunWholeModule(
+      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
+  tracker.ExpectCallSequence({
+      // The entry computation param and while_op are allocated first.
+      {kAlloc, tracker.BufferAt(param, {})},
+      {kAlloc, tracker.BufferAt(param, {0})},
+      {kAlloc, tracker.BufferAt(param, {1})},
+      {kAlloc, tracker.BufferAt(while_op, {})},
+      {kAlloc, tracker.BufferAt(while_op, {0})},
+      {kAlloc, tracker.BufferAt(while_op, {1})},
+
+      // Now the while body param is allocated and freed.
+      {kAlloc, tracker.BufferAt(body_param, {})},
+      {kAlloc, tracker.BufferAt(body_param, {0})},
+      {kAlloc, tracker.BufferAt(body_param, {1})},
+      {kFree, tracker.BufferAt(body_param, {})},
+      {kFree, tracker.BufferAt(body_param, {0})},
+      {kFree, tracker.BufferAt(body_param, {1})},
+
+      // Now the while cond param is allocated. The GTE instructions just alias
+      // the param elements, so the param tuple can immediately be freed.
+      {kAlloc, tracker.BufferAt(cond_param, {})},
+      {kAlloc, tracker.BufferAt(cond_param, {0})},
+      {kAlloc, tracker.BufferAt(cond_param, {1})},
+      {kFree, tracker.BufferAt(cond_param, {})},
+
+      // Now the final cond less-than buffer is allocated.
+      {kAlloc, tracker.BufferAt(cond_lt, {})},
+
+      // The order of the remaining Free calls is based on the LogicalBuffer.id,
+      // which is deterministic, but not obvious.
+      {kFree, tracker.BufferAt(param, {})},
+      {kFree, tracker.BufferAt(param, {0})},
+      {kFree, tracker.BufferAt(param, {1})},
+
+      {kFree, tracker.BufferAt(while_op, {})},
+      {kFree, tracker.BufferAt(while_op, {0})},
+      {kFree, tracker.BufferAt(while_op, {1})},
+
+      {kFree, tracker.BufferAt(cond_param, {0})},
+      {kFree, tracker.BufferAt(cond_param, {1})},
+      {kFree, tracker.BufferAt(cond_lt, {})},
+
+      {kFinish, nullptr},
+  });
+}
+
 // Base class for heap algorithm tests.
 class HeapAlgorithmTestBase : public ::testing::Test {
  protected:

View File

@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace op = xla::testing::opcode_matchers;
@@ -49,8 +50,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
@@ -70,8 +72,9 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
@@ -91,8 +94,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(
@@ -131,11 +135,12 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
     Shape shape = ShapeUtil::MakeShape(F32, dimensions);
     builder.AddInstruction(HloInstruction::CreateConcatenate(
         shape, operands, test_config.concat_dimension));
-    HloModule module(TestName());
-    auto computation = module.AddEntryComputation(builder.Build());
+    auto module = MakeUnique<HloModule>(TestName());
+    auto computation = module->AddEntryComputation(builder.Build());
 
-    HloConstantFolding simplifier;
-    ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+    HloConstantFolding const_folder;
+    TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+    EXPECT_TRUE(result);
 
     HloInstruction* root = computation->root_instruction();
     EXPECT_THAT(root, op::Constant());
@@ -148,22 +153,61 @@ TEST_F(HloConstantFoldingTest, Slice) {
   const int64 dimensions[] = {11, 8, 7, 5, 9};
   const int64 slice_start[] = {4, 2, 3, 1, 5};
   const int64 slice_limits[] = {10, 8, 6, 5, 9};
-  auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
-  HloInstruction* lit_insn = builder.AddInstruction(
+  TF_ASSIGN_OR_ASSERT_OK(auto literal,
+                         LiteralTestUtil::CreateRandomLiteral<F32>(
+                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+  HloInstruction* literal_instruction = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
   Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
-  builder.AddInstruction(
-      HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  builder.AddInstruction(HloInstruction::CreateSlice(
+      shape, literal_instruction, slice_start, slice_limits));
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Constant());
   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
 }
 
+TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
+  HloComputation::Builder builder(TestName());
+  const int64 dimensions[] = {11, 8, 7, 5, 9};
+  TF_ASSIGN_OR_ASSERT_OK(auto literal,
+                         LiteralTestUtil::CreateRandomLiteral<F32>(
+                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+  auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+  HloInstruction* literal_instruction = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+  const int64 permutation[] = {1, 2, 0, 4, 3};
+  builder.AddInstruction(
+      HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
+
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_THAT(root, op::Constant());
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+  bool matched = true;
+  LiteralUtil::EachCell<NativeT>(
+      root->literal(),
+      [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+        std::vector<int64> rindexes = Permute(permutation, indices);
+        matched = matched && (value == LiteralUtil::Get<NativeT>(*literal_clone,
+                                                                 rindexes));
+      });
+  EXPECT_TRUE(matched);
+}
+
 }  // namespace
 }  // namespace xla

View File

@@ -1570,7 +1570,9 @@ string HloInstruction::ToCategory() const {
         return "non-elementwise fusion";
       }
       case FusionKind::kInput:
-        return "reduce fusion";
+        return "input fusion";
+      case FusionKind::kOutput:
+        return "output fusion";
       case FusionKind::kTransposeDot:
         return "dot fusion";
       case FusionKind::kConvBackwardFilter:
@@ -1618,7 +1620,6 @@ bool HloInstruction::IsFusable() const {
   // Some kinds of instructions don't make sense to fuse.
   switch (opcode_) {
-    case HloOpcode::kFusion:
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
     case HloOpcode::kParameter:
@@ -2186,6 +2187,8 @@ string ToString(HloInstruction::FusionKind kind) {
       return "kLoop";
     case HloInstruction::FusionKind::kInput:
       return "kInput";
+    case HloInstruction::FusionKind::kOutput:
+      return "kOutput";
     case HloInstruction::FusionKind::kTransposeDot:
       return "kTransposeDot";
     case HloInstruction::FusionKind::kConvBackwardFilter:

View File

@@ -54,7 +54,8 @@ class HloInstruction {
  public:
   enum class FusionKind {
     kLoop,                // Fused into a loop.
-    kInput,               // Fused into a reduction kernel.
+    kInput,               // Op's input is fused into the op itself.
+    kOutput,              // Op's output is fused into the op itself.
     kTransposeDot,        // Fused into a dot with transposed operands.
     kConvBackwardFilter,  // Fused into a backward filter convolution.
     kConvBackwardInput,   // Fused into a backward input convolution.

View File

@@ -221,23 +221,6 @@ string SequentialHloOrdering::ToString() const {
   return tensorflow::str_util::Join(pieces, "\n");
 }
 
-namespace {
-StatusOr<int64> MinimumMemoryForSequence(
-    const HloComputation& computation,
-    const std::vector<const HloInstruction*>& sequence,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function) {
-  // The absolute minimum memory required for a given sequence of instructions
-  // is determined by the sequence of Alloc and Free calls on a simulated heap,
-  // ignoring fragmentation.
-  TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), sequence,
-                         computation, points_to_analysis, size_function));
-  return result.heap_size;
-}
-}  // namespace
-
 StatusOr<int64> MinimumMemoryForSequence(
     const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const LogicalBuffer::SizeFunction& size_function) {
@@ -249,17 +232,16 @@ StatusOr<int64> MinimumMemoryForSequence(
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(module));
 
-  int64 total_memory = 0;
-  for (const auto& pair : module_sequence) {
-    const HloComputation* computation = pair.first;
-    const std::vector<const HloInstruction*>& sequence = pair.second;
-    TF_ASSIGN_OR_RETURN(
-        const int64 memory,
-        MinimumMemoryForSequence(*computation, sequence, *points_to_analysis,
-                                 size_function));
-    total_memory += memory;
-  }
-  return total_memory;
+  // The absolute minimum memory required for a given sequence of instructions
+  // is determined by the sequence of Alloc and Free calls on a simulated heap,
+  // ignoring fragmentation. We run the heap simulation on the whole module,
+  // rather than summing each computation, since it gives us a better lower
+  // bound, by minimizing the liveness of sub-computations.
+  TF_ASSIGN_OR_RETURN(
+      HeapSimulator::Result result,
+      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
+                         module_sequence, *points_to_analysis, size_function));
+  return result.heap_size;
 }
 
 namespace {
@ -516,6 +498,18 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
return sequence; return sequence;
} }
StatusOr<int64> MinimumMemoryForComputation(
const HloComputation& computation,
const std::vector<const HloInstruction*>& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
sequence, points_to_analysis, size_function));
return result.heap_size;
}
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
const HloComputation& computation, const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
@ -523,13 +517,17 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
// We try both a list-scheduler based ordering and a DFS based ordering, and // We try both a list-scheduler based ordering and a DFS based ordering, and
// choose whichever returns a lower min-memory, not accounting for // choose whichever returns a lower min-memory, not accounting for
// fragmentation. // fragmentation.
//
// Note that this is just a heuristic. One obvious inaccuracy is that the
// memory required for sub-computations might be different when considered
// within the caller's context. But it's good enough for now.
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> list_sequence, std::vector<const HloInstruction*> list_sequence,
ListScheduler::Run(computation, points_to_analysis, size_function)); ListScheduler::Run(computation, points_to_analysis, size_function));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
const int64 list_memory, const int64 list_memory,
MinimumMemoryForSequence(computation, list_sequence, points_to_analysis, MinimumMemoryForComputation(computation, list_sequence,
size_function)); points_to_analysis, size_function));
VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
@ -537,8 +535,8 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
const int64 dfs_memory, const int64 dfs_memory,
MinimumMemoryForSequence(computation, dfs_sequence, points_to_analysis, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
size_function)); size_function));
VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes";
if (list_memory <= dfs_memory) { if (list_memory <= dfs_memory) {
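The comments in this hunk describe the minimum-memory computation in terms of a simulated heap: Alloc when a buffer is defined, Free after its last use, no fragmentation, and a single module-wide run gives a tighter lower bound than summing per-computation peaks. A rough Python sketch of that idea follows; it is not the XLA HeapSimulator, and buffer_sizes/uses are hypothetical stand-ins for the points-to analysis.

def minimum_memory_for_sequence(sequence, buffer_sizes, uses):
    """sequence: instruction names in execution order.
    buffer_sizes: instruction -> size of the buffer it defines.
    uses: instruction -> instructions whose buffers it reads."""
    last_use = {}
    for idx, instr in enumerate(sequence):
        for operand in uses.get(instr, []):
            last_use[operand] = idx
    live = peak = 0
    for idx, instr in enumerate(sequence):
        live += buffer_sizes.get(instr, 0)            # Alloc at definition.
        peak = max(peak, live)
        for operand, last in last_use.items():
            if last == idx:
                live -= buffer_sizes.get(operand, 0)  # Free after last use.
    return peak

Running one simulation over the flattened module sequence, as the new code does, lets sub-computation buffers die before the caller's later instructions, which is why it can only lower the result relative to summing per-computation peaks.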

View File

@ -155,6 +155,65 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(y, c)); EXPECT_FALSE(ordering.ExecutesBefore(y, c));
} }
class MinimumMemoryForSequenceTest : public HloTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
HloModule module(TestName());
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
auto cond_builder = HloComputation::Builder("WhileCond");
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
HloInstruction* cond_iter = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
HloInstruction* cond_data = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kLt, cond_iter, cond_data));
HloComputation* cond_computation =
module.AddEmbeddedComputation(cond_builder.Build());
auto body_builder = HloComputation::Builder("WhileBody");
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
HloComputation* body_computation =
module.AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
// Entry params: 8 bytes (4 bytes per param), TOTAL=8
HloInstruction* iter = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
HloInstruction* data = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
// Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
// While: 8 bytes (4 bytes per element), TOTAL=32
// Both cond and body use a max of 24 bytes, TOTAL=56
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
HloComputation* entry_computation =
module.AddEntryComputation(builder.Build());
auto size_fn = [](const LogicalBuffer& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
SequentialHloOrdering::HloModuleSequence module_sequence;
module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
cond_lt};
module_sequence[body_computation] = {body_param};
module_sequence[entry_computation] = {iter, data, tuple, while_op};
EXPECT_EQ(56,
MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -1160,28 +1160,25 @@ StatusOr<bool> HloRematerialization::Run(
TuplePointsToAnalysis::Run( TuplePointsToAnalysis::Run(
module, /*include_loop_fusion_instructions=*/true)); module, /*include_loop_fusion_instructions=*/true));
// Adjust memory limit to account for the parameter and output of the entry // Adjust memory limit to account for the output of the entry
// computation. This is necessary because the per-computation accounting in // computation. This is necessary because the per-computation accounting in
// MemoryUsageTracker do not include parameters and output as these are // MemoryUsageTracker do not include output as these are typically allocated
// typically allocated by the caller. With this adjustment the memory limit // by the caller.
// accounts for the size of all HLO instructions (parameters, output int64 module_output_size = 0;
// instructions, etc). ShapeUtil::ForEachSubshape(
auto total_size = [this](const HloInstruction* instruction) { module->entry_computation()->root_instruction()->shape(),
int64 total_size = 0; [&module_output_size, this](const Shape& subshape,
for (const LogicalBuffer* logical_buffer : const ShapeIndex& /*index*/) {
points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) { module_output_size += size_function_(subshape);
total_size += size_function_(logical_buffer->shape()); return Status::OK();
} })
return total_size; .IgnoreError();
};
const HloComputation* entry_computation = module->entry_computation(); const int64 adjusted_memory_limit_bytes =
memory_limit_bytes -= total_size(entry_computation->root_instruction()); memory_limit_bytes - module_output_size;
for (const HloInstruction* param : VLOG(1) << "Adjusted memory limit accounting for output ("
entry_computation->parameter_instructions()) { << HumanReadableNumBytes(module_output_size)
memory_limit_bytes -= total_size(param); << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
}
VLOG(1) << "Adjusted memory limit accounting for parameters and output: "
<< HumanReadableNumBytes(memory_limit_bytes);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
// Create initial sequence of HLO instructions. // Create initial sequence of HLO instructions.
@ -1204,8 +1201,13 @@ StatusOr<bool> HloRematerialization::Run(
return Status::OK(); return Status::OK();
})); }));
// The peak memory usage of the module equals the peak memory use of the entry
// computation plus the output size of the computation. This is because the
// peak memory for a computation does not include the output as this is
// typically accounted for in the caller.
const int64 before_peak_memory = const int64 before_peak_memory =
computation_peak_memory_.at(module->entry_computation()); computation_peak_memory_.at(module->entry_computation()) +
module_output_size;
VLOG(1) << "Peak memory usage of module (before): " VLOG(1) << "Peak memory usage of module (before): "
<< HumanReadableNumBytes(before_peak_memory); << HumanReadableNumBytes(before_peak_memory);
@ -1216,9 +1218,9 @@ StatusOr<bool> HloRematerialization::Run(
// Subcomputations called by the entry computation will also be // Subcomputations called by the entry computation will also be
// rematerialized. // rematerialized.
TF_ASSIGN_OR_RETURN(bool changed, TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
RematerializeComputation(module->entry_computation(), module->entry_computation(), sequence,
sequence, memory_limit_bytes)); adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an // Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction. // instruction are replaced with rematerializations of the instruction.
@ -1257,7 +1259,8 @@ StatusOr<bool> HloRematerialization::Run(
<< " instructions in module " << module->name() << "; " << " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added"; << net_instructions_added_ << " net instructions added";
const int64 current_peak_memory = const int64 current_peak_memory =
computation_peak_memory_.at(module->entry_computation()); computation_peak_memory_.at(module->entry_computation()) +
module_output_size;
VLOG(1) << "Peak memory usage of module now " VLOG(1) << "Peak memory usage of module now "
<< HumanReadableNumBytes(current_peak_memory) << " (" << HumanReadableNumBytes(current_peak_memory) << " ("
<< current_peak_memory << " bytes), was " << current_peak_memory << " bytes), was "
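A small worked example of the accounting change in this hunk, with purely illustrative sizes: the caller owns the entry computation's output, so the limit handed to rematerialization shrinks by the output size, while the reported module peak adds it back.

memory_limit_bytes = 1024 * 1024        # requested module limit
module_output_size = 64 * 1024          # sum of size_function_ over the root's subshapes
adjusted_memory_limit_bytes = memory_limit_bytes - module_output_size  # 983040, passed to RematerializeComputation
entry_peak = 700 * 1024                 # computation_peak_memory_ of the entry computation
module_peak = entry_peak + module_output_size                          # 782336, what the VLOG lines report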

View File

@ -1928,6 +1928,12 @@ HloInstruction* ComputationLowerer::Visit(
const OperationRequest& request = const OperationRequest& request =
session_computation_.requests().at(handle.handle()); session_computation_.requests().at(handle.handle());
auto add_instruction = [&](std::unique_ptr<HloInstruction> instruction) {
HloInstruction* hlo_instruction =
hlo_builder_.AddInstruction(std::move(instruction));
hlo_instruction->set_metadata(request.request().metadata());
return hlo_instruction;
};
HloInstruction* hlo_instruction; HloInstruction* hlo_instruction;
switch (request.request().op_case()) { switch (request.request().op_case()) {
case OpRequest::kRngRequest: { case OpRequest::kRngRequest: {
@ -1936,7 +1942,7 @@ HloInstruction* ComputationLowerer::Visit(
for (const ComputationDataHandle& param : rng_request.parameter()) { for (const ComputationDataHandle& param : rng_request.parameter()) {
parameters.push_back(Visit(param, visited)); parameters.push_back(Visit(param, visited));
} }
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng( hlo_instruction = add_instruction(HloInstruction::CreateRng(
request.output_shape(), rng_request.distribution(), parameters)); request.output_shape(), rng_request.distribution(), parameters));
break; break;
} }
@ -1944,9 +1950,8 @@ HloInstruction* ComputationLowerer::Visit(
case OpRequest::kConstantRequest: { case OpRequest::kConstantRequest: {
const ConstantRequest& constant_request = const ConstantRequest& constant_request =
request.request().constant_request(); request.request().constant_request();
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateConstant(
hlo_builder_.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CloneToUnique(constant_request.literal())));
LiteralUtil::CloneToUnique(constant_request.literal())));
break; break;
} }
@ -1955,17 +1960,15 @@ HloInstruction* ComputationLowerer::Visit(
request.request().get_tuple_element_request(); request.request().get_tuple_element_request();
HloInstruction* operand = HloInstruction* operand =
Visit(get_tuple_element_request.operand(), visited); Visit(get_tuple_element_request.operand(), visited);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement(
hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement( request.output_shape(), operand, get_tuple_element_request.index()));
request.output_shape(), operand,
get_tuple_element_request.index()));
break; break;
} }
case OpRequest::kSliceRequest: { case OpRequest::kSliceRequest: {
const SliceRequest& slice_request = request.request().slice_request(); const SliceRequest& slice_request = request.request().slice_request();
HloInstruction* operand = Visit(slice_request.operand(), visited); HloInstruction* operand = Visit(slice_request.operand(), visited);
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice( hlo_instruction = add_instruction(HloInstruction::CreateSlice(
request.output_shape(), operand, request.output_shape(), operand,
AsInt64Slice(slice_request.start_indices()), AsInt64Slice(slice_request.start_indices()),
AsInt64Slice(slice_request.limit_indices()))); AsInt64Slice(slice_request.limit_indices())));
@ -1979,10 +1982,9 @@ HloInstruction* ComputationLowerer::Visit(
HloInstruction* start_indices = HloInstruction* start_indices =
Visit(dynamic_slice_request.start_indices(), visited); Visit(dynamic_slice_request.start_indices(), visited);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice(
hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice( request.output_shape(), operand, start_indices,
request.output_shape(), operand, start_indices, AsInt64Slice(dynamic_slice_request.slice_sizes())));
AsInt64Slice(dynamic_slice_request.slice_sizes())));
break; break;
} }
@ -1996,7 +1998,7 @@ HloInstruction* ComputationLowerer::Visit(
HloInstruction* start_indices = HloInstruction* start_indices =
Visit(dynamic_update_slice_request.start_indices(), visited); Visit(dynamic_update_slice_request.start_indices(), visited);
hlo_instruction = hlo_instruction =
hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( add_instruction(HloInstruction::CreateDynamicUpdateSlice(
request.output_shape(), operand, update, start_indices)); request.output_shape(), operand, update, start_indices));
break; break;
} }
@ -2010,9 +2012,8 @@ HloInstruction* ComputationLowerer::Visit(
HloInstruction* operand = Visit(handle, visited); HloInstruction* operand = Visit(handle, visited);
operands.push_back(operand); operands.push_back(operand);
} }
hlo_instruction = hlo_builder_.AddInstruction( hlo_instruction = add_instruction(HloInstruction::CreateConcatenate(
HloInstruction::CreateConcatenate(request.output_shape(), operands, request.output_shape(), operands, concatenate_request.dimension()));
concatenate_request.dimension()));
break; break;
} }
@ -2021,10 +2022,9 @@ HloInstruction* ComputationLowerer::Visit(
request.request().convolve_request(); request.request().convolve_request();
HloInstruction* lhs = Visit(convolve_request.lhs(), visited); HloInstruction* lhs = Visit(convolve_request.lhs(), visited);
HloInstruction* rhs = Visit(convolve_request.rhs(), visited); HloInstruction* rhs = Visit(convolve_request.rhs(), visited);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateConvolve(
hlo_builder_.AddInstruction(HloInstruction::CreateConvolve( request.output_shape(), lhs, rhs, convolve_request.window(),
request.output_shape(), lhs, rhs, convolve_request.window(), convolve_request.dimension_numbers()));
convolve_request.dimension_numbers()));
break; break;
} }
@ -2033,17 +2033,15 @@ HloInstruction* ComputationLowerer::Visit(
request.request().cross_replica_sum_request(); request.request().cross_replica_sum_request();
HloInstruction* operand = HloInstruction* operand =
Visit(cross_replica_sum_request.operand(), visited); Visit(cross_replica_sum_request.operand(), visited);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum(
hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum( request.output_shape(), operand));
request.output_shape(), operand));
break; break;
} }
case OpRequest::kInfeedRequest: { case OpRequest::kInfeedRequest: {
const InfeedRequest& infeed_request = request.request().infeed_request(); const InfeedRequest& infeed_request = request.request().infeed_request();
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateInfeed(
hlo_builder_.AddInstruction(HloInstruction::CreateInfeed( request.output_shape(), infeed_request.config()));
request.output_shape(), infeed_request.config()));
break; break;
} }
@ -2051,9 +2049,8 @@ HloInstruction* ComputationLowerer::Visit(
const OutfeedRequest& outfeed_request = const OutfeedRequest& outfeed_request =
request.request().outfeed_request(); request.request().outfeed_request();
HloInstruction* operand = Visit(outfeed_request.operand(), visited); HloInstruction* operand = Visit(outfeed_request.operand(), visited);
hlo_instruction = hlo_builder_.AddInstruction( hlo_instruction = add_instruction(HloInstruction::CreateOutfeed(
HloInstruction::CreateOutfeed(outfeed_request.shape(), operand, outfeed_request.shape(), operand, outfeed_request.outfeed_config()));
outfeed_request.outfeed_config()));
break; break;
} }
@ -2069,7 +2066,7 @@ HloInstruction* ComputationLowerer::Visit(
request.embedded_computation_versions(0); request.embedded_computation_versions(0);
HloComputation* map_computation = HloComputation* map_computation =
ResolveComputation(map_request.to_apply(), map_version); ResolveComputation(map_request.to_apply(), map_version);
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap( hlo_instruction = add_instruction(HloInstruction::CreateMap(
request.output_shape(), operands, map_computation)); request.output_shape(), operands, map_computation));
break; break;
} }
@ -2083,10 +2080,9 @@ HloInstruction* ComputationLowerer::Visit(
request.embedded_computation_versions(0); request.embedded_computation_versions(0);
HloComputation* reduce_computation = HloComputation* reduce_computation =
ResolveComputation(reduce_request.to_apply(), reduce_version); ResolveComputation(reduce_request.to_apply(), reduce_version);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateReduce(
hlo_builder_.AddInstruction(HloInstruction::CreateReduce( request.output_shape(), operand, init_value,
request.output_shape(), operand, init_value, AsInt64Slice(reduce_request.dimensions()), reduce_computation));
AsInt64Slice(reduce_request.dimensions()), reduce_computation));
break; break;
} }
@ -2101,10 +2097,9 @@ HloInstruction* ComputationLowerer::Visit(
request.embedded_computation_versions(0); request.embedded_computation_versions(0);
HloComputation* reduce_window_computation = ResolveComputation( HloComputation* reduce_window_computation = ResolveComputation(
reduce_window_request.to_apply(), reduce_window_version); reduce_window_request.to_apply(), reduce_window_version);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow(
hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow( request.output_shape(), operand, init_value,
request.output_shape(), operand, init_value, reduce_window_request.window(), reduce_window_computation));
reduce_window_request.window(), reduce_window_computation));
break; break;
} }
@ -2126,11 +2121,10 @@ HloInstruction* ComputationLowerer::Visit(
select_and_scatter_request.select(), select_version); select_and_scatter_request.select(), select_version);
HloComputation* scatter_computation = ResolveComputation( HloComputation* scatter_computation = ResolveComputation(
select_and_scatter_request.scatter(), scatter_version); select_and_scatter_request.scatter(), scatter_version);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter(
hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter( request.output_shape(), operand, select_computation,
request.output_shape(), operand, select_computation, select_and_scatter_request.window(), source, init_value,
select_and_scatter_request.window(), source, init_value, scatter_computation));
scatter_computation));
break; break;
} }
@ -2151,9 +2145,8 @@ HloInstruction* ComputationLowerer::Visit(
ShapeUtil::Rank(request.output_shape()) - ShapeUtil::Rank(request.output_shape()) -
ShapeUtil::Rank(operand->shape())); ShapeUtil::Rank(operand->shape()));
} }
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateBroadcast(
hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( request.output_shape(), operand, broadcast_dimensions));
request.output_shape(), operand, broadcast_dimensions));
break; break;
} }
@ -2165,14 +2158,13 @@ HloInstruction* ComputationLowerer::Visit(
if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) {
transposed = operand; transposed = operand;
} else { } else {
transposed = transposed = add_instruction(HloInstruction::CreateTranspose(
hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::PermuteDimensions(
ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice( InversePermutation(AsInt64Slice(reshape_request.dimensions())),
reshape_request.dimensions())), operand->shape()),
operand->shape()), operand, AsInt64Slice(reshape_request.dimensions())));
operand, AsInt64Slice(reshape_request.dimensions())));
} }
hlo_instruction = hlo_builder_.AddInstruction( hlo_instruction = add_instruction(
HloInstruction::CreateReshape(request.output_shape(), transposed)); HloInstruction::CreateReshape(request.output_shape(), transposed));
break; break;
} }
@ -2181,12 +2173,11 @@ HloInstruction* ComputationLowerer::Visit(
const TransposeRequest& transpose_request = const TransposeRequest& transpose_request =
request.request().transpose_request(); request.request().transpose_request();
HloInstruction* operand = Visit(transpose_request.operand(), visited); HloInstruction* operand = Visit(transpose_request.operand(), visited);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateTranspose(
hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::PermuteDimensions(
ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice( InversePermutation(AsInt64Slice(transpose_request.dimensions())),
transpose_request.dimensions())), operand->shape()),
operand->shape()), operand, AsInt64Slice(transpose_request.dimensions())));
operand, AsInt64Slice(transpose_request.dimensions())));
break; break;
} }
@ -2194,10 +2185,9 @@ HloInstruction* ComputationLowerer::Visit(
const ReverseRequest& reverse_request = const ReverseRequest& reverse_request =
request.request().reverse_request(); request.request().reverse_request();
HloInstruction* operand = Visit(reverse_request.operand(), visited); HloInstruction* operand = Visit(reverse_request.operand(), visited);
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateReverse(
hlo_builder_.AddInstruction(HloInstruction::CreateReverse( request.output_shape(), operand,
request.output_shape(), operand, AsInt64Slice(reverse_request.dimensions())));
AsInt64Slice(reverse_request.dimensions())));
break; break;
} }
@ -2206,7 +2196,7 @@ HloInstruction* ComputationLowerer::Visit(
HloInstruction* operand = Visit(pad_request.operand(), visited); HloInstruction* operand = Visit(pad_request.operand(), visited);
HloInstruction* padding_value = HloInstruction* padding_value =
Visit(pad_request.padding_value(), visited); Visit(pad_request.padding_value(), visited);
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad( hlo_instruction = add_instruction(HloInstruction::CreatePad(
request.output_shape(), operand, padding_value, request.output_shape(), operand, padding_value,
pad_request.padding_config())); pad_request.padding_config()));
break; break;
@ -2214,7 +2204,7 @@ HloInstruction* ComputationLowerer::Visit(
case OpRequest::kRecvRequest: { case OpRequest::kRecvRequest: {
const RecvRequest& recv_request = request.request().recv_request(); const RecvRequest& recv_request = request.request().recv_request();
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRecv( hlo_instruction = add_instruction(HloInstruction::CreateRecv(
request.output_shape(), recv_request.channel_handle().handle())); request.output_shape(), recv_request.channel_handle().handle()));
break; break;
} }
@ -2222,10 +2212,9 @@ HloInstruction* ComputationLowerer::Visit(
case OpRequest::kParameterRequest: { case OpRequest::kParameterRequest: {
const ParameterRequest& parameter_request = const ParameterRequest& parameter_request =
request.request().parameter_request(); request.request().parameter_request();
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateParameter(
hlo_builder_.AddInstruction(HloInstruction::CreateParameter( parameter_request.parameter(), request.output_shape(),
parameter_request.parameter(), request.output_shape(), parameter_request.name()));
parameter_request.name()));
break; break;
} }
@ -2233,7 +2222,7 @@ HloInstruction* ComputationLowerer::Visit(
const ConvertRequest& convert_request = const ConvertRequest& convert_request =
request.request().convert_request(); request.request().convert_request();
HloInstruction* operand = Visit(convert_request.operand(), visited); HloInstruction* operand = Visit(convert_request.operand(), visited);
hlo_instruction = hlo_builder_.AddInstruction( hlo_instruction = add_instruction(
HloInstruction::CreateConvert(request.output_shape(), operand)); HloInstruction::CreateConvert(request.output_shape(), operand));
break; break;
} }
@ -2250,7 +2239,7 @@ HloInstruction* ComputationLowerer::Visit(
HloComputation* body = HloComputation* body =
ResolveComputation(while_request.body(), body_version); ResolveComputation(while_request.body(), body_version);
HloInstruction* init = Visit(while_request.init(), visited); HloInstruction* init = Visit(while_request.init(), visited);
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile( hlo_instruction = add_instruction(HloInstruction::CreateWhile(
request.output_shape(), condition, body, init)); request.output_shape(), condition, body, init));
break; break;
} }
@ -2262,9 +2251,8 @@ HloInstruction* ComputationLowerer::Visit(
HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited); HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited);
HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited); HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited);
auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateTernary(
hlo_builder_.AddInstruction(HloInstruction::CreateTernary( request.output_shape(), hlo_opcode, lhs, rhs, ehs));
request.output_shape(), hlo_opcode, lhs, rhs, ehs));
break; break;
} }
@ -2279,9 +2267,8 @@ HloInstruction* ComputationLowerer::Visit(
} }
auto hlo_opcode = auto hlo_opcode =
VariadicOperationToHloOpcode(variadic_op_request.varop()); VariadicOperationToHloOpcode(variadic_op_request.varop());
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateVariadic(
hlo_builder_.AddInstruction(HloInstruction::CreateVariadic( request.output_shape(), hlo_opcode, operands));
request.output_shape(), hlo_opcode, operands));
break; break;
} }
@ -2296,7 +2283,7 @@ HloInstruction* ComputationLowerer::Visit(
request.embedded_computation_versions(0); request.embedded_computation_versions(0);
HloComputation* call_computation = HloComputation* call_computation =
ResolveComputation(call_request.to_apply(), call_version); ResolveComputation(call_request.to_apply(), call_version);
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall( hlo_instruction = add_instruction(HloInstruction::CreateCall(
request.output_shape(), operands, call_computation)); request.output_shape(), operands, call_computation));
break; break;
} }
@ -2308,9 +2295,8 @@ HloInstruction* ComputationLowerer::Visit(
for (const ComputationDataHandle& operand : cc_request.operands()) { for (const ComputationDataHandle& operand : cc_request.operands()) {
operands.push_back(Visit(operand, visited)); operands.push_back(Visit(operand, visited));
} }
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateCustomCall(
hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall( cc_request.shape(), operands, cc_request.call_target_name()));
cc_request.shape(), operands, cc_request.call_target_name()));
break; break;
} }
@ -2319,7 +2305,7 @@ HloInstruction* ComputationLowerer::Visit(
request.request().unary_op_request(); request.request().unary_op_request();
HloInstruction* operand = Visit(unary_op_request.operand(), visited); HloInstruction* operand = Visit(unary_op_request.operand(), visited);
auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop());
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary( hlo_instruction = add_instruction(HloInstruction::CreateUnary(
request.output_shape(), hlo_opcode, operand)); request.output_shape(), hlo_opcode, operand));
break; break;
} }
@ -2347,23 +2333,22 @@ HloInstruction* ComputationLowerer::Visit(
// identical to the HLO broadcast semantics so the broadcast_dimensions // identical to the HLO broadcast semantics so the broadcast_dimensions
// field can just be passed to the instruction builder. // field can just be passed to the instruction builder.
HloInstruction* broadcasted_operand = HloInstruction* broadcasted_operand =
hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( add_instruction(HloInstruction::CreateBroadcast(
broadcast_shape, operand_to_broadcast, broadcast_shape, operand_to_broadcast,
AsInt64Slice(binary_op_request.broadcast_dimensions()))); AsInt64Slice(binary_op_request.broadcast_dimensions())));
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
} }
hlo_instruction = hlo_instruction = add_instruction(HloInstruction::CreateBinary(
hlo_builder_.AddInstruction(HloInstruction::CreateBinary( request.output_shape(), hlo_opcode, lhs, rhs));
request.output_shape(), hlo_opcode, lhs, rhs));
break; break;
} }
case OpRequest::kTraceRequest: { case OpRequest::kTraceRequest: {
const TraceRequest& trace_request = request.request().trace_request(); const TraceRequest& trace_request = request.request().trace_request();
HloInstruction* operand = Visit(trace_request.operand(), visited); HloInstruction* operand = Visit(trace_request.operand(), visited);
hlo_instruction = hlo_builder_.AddInstruction( hlo_instruction = add_instruction(
HloInstruction::CreateTrace(trace_request.tag(), operand)); HloInstruction::CreateTrace(trace_request.tag(), operand));
operand->set_tracing(hlo_instruction); operand->set_tracing(hlo_instruction);
break; break;
@ -2372,7 +2357,7 @@ HloInstruction* ComputationLowerer::Visit(
case OpRequest::kSendRequest: { case OpRequest::kSendRequest: {
const SendRequest& send_request = request.request().send_request(); const SendRequest& send_request = request.request().send_request();
HloInstruction* operand = Visit(send_request.operand(), visited); HloInstruction* operand = Visit(send_request.operand(), visited);
hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSend( hlo_instruction = add_instruction(HloInstruction::CreateSend(
operand, send_request.channel_handle().handle())); operand, send_request.channel_handle().handle()));
break; break;
} }
@ -2383,7 +2368,6 @@ HloInstruction* ComputationLowerer::Visit(
default: default:
LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
} }
hlo_instruction->set_metadata(request.request().metadata());
(*visited)[handle.handle()] = hlo_instruction; (*visited)[handle.handle()] = hlo_instruction;
return hlo_instruction; return hlo_instruction;
} }

View File

@ -59,6 +59,9 @@ TEST_F(UserComputationTest, SimpleComputation) {
param_request.set_name("param0"); param_request.set_name("param0");
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle, TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle,
computation.AddParameterInstruction(param_request)); computation.AddParameterInstruction(param_request));
OpMetadata metadata;
metadata.set_op_name("meta");
TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata));
OutfeedRequest outfeed_request; OutfeedRequest outfeed_request;
*outfeed_request.mutable_operand() = constant_handle; *outfeed_request.mutable_operand() = constant_handle;
@ -135,6 +138,8 @@ TEST_F(UserComputationTest, SimpleComputation) {
// The root of the instruction should be the parameter instruction (not the // The root of the instruction should be the parameter instruction (not the
// outfeed). // outfeed).
EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(),
"meta");
} }
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <initializer_list> #include <initializer_list>
#include <memory> #include <memory>
#include <random>
#include <string> #include <string>
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array2d.h"
@ -171,6 +172,36 @@ class LiteralTestUtil {
tensorflow::gtl::ArraySlice<int64> minor_to_major, tensorflow::gtl::ArraySlice<int64> minor_to_major,
const Literal& literal); const Literal& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with the given mean and standard
// deviation (stddev), and the given engine as the entropy source.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with the given mean and standard
// deviation (stddev).
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, T mean, T stddev);
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
}; };
@ -270,6 +301,40 @@ template <typename NativeT>
ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error);
} }
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
std::unique_ptr<Literal> literal = LiteralUtil::CreateFromShape(shape);
TF_RETURN_IF_ERROR(LiteralUtil::Populate<NativeT>(
literal.get(), [&](tensorflow::gtl::ArraySlice<int64> indexes) {
return generator(indexes);
}));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
return generator(*engine);
});
}
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
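For readers more comfortable in Python, here is a NumPy sketch of what the two helpers above do, assuming F32 literals; it is not the XLA API, just the same idea of filling every cell from a per-index generator or from a seeded normal distribution.

import numpy as np

def create_random_literal(shape, mean=0.0, stddev=1.0, seed=0):
    engine = np.random.RandomState(seed)      # plays the role of the entropy engine
    return engine.normal(loc=mean, scale=stddev, size=shape).astype(np.float32)

def populate(shape, generator):
    out = np.empty(shape, dtype=np.float32)
    for index in np.ndindex(*shape):          # visit every cell, like LiteralUtil::Populate
        out[index] = generator(index)
    return out

# e.g. create_random_literal((2, 3), mean=0.0, stddev=2.0, seed=42)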

View File

@ -741,6 +741,7 @@ class Dense(tf_core_layers.Dense, Layer):
self.constraints[self.kernel] = self.kernel_constraint self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint: if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint self.constraints[self.bias] = self.bias_constraint
self.built = True
def get_config(self): def get_config(self):
config = { config = {

View File

@ -111,6 +111,7 @@ class _Merge(Layer):
self._reshape_required = False self._reshape_required = False
else: else:
self._reshape_required = True self._reshape_required = True
self.built = True
def call(self, inputs): def call(self, inputs):
if self._reshape_required: if self._reshape_required:
@ -302,6 +303,7 @@ class Concatenate(_Merge):
'inputs with matching shapes ' 'inputs with matching shapes '
'except for the concat axis. ' 'except for the concat axis. '
'Got inputs shapes: %s' % (input_shape)) 'Got inputs shapes: %s' % (input_shape))
self.built = True
def call(self, inputs): def call(self, inputs):
if not isinstance(inputs, list): if not isinstance(inputs, list):
@ -414,6 +416,7 @@ class Dot(_Merge):
raise ValueError('Dimension incompatibility ' raise ValueError('Dimension incompatibility '
'%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) + '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
'Layer shapes: %s, %s' % (shape1, shape2)) 'Layer shapes: %s, %s' % (shape1, shape2))
self.built = True
def call(self, inputs): def call(self, inputs):
x1 = inputs[0] x1 = inputs[0]

View File

@ -166,6 +166,7 @@ class TimeDistributed(Wrapper):
self.layer.build(child_input_shape) self.layer.build(child_input_shape)
self.layer.built = True self.layer.built = True
super(TimeDistributed, self).build() super(TimeDistributed, self).build()
self.built = True
def _compute_output_shape(self, input_shape): def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list() input_shape = tensor_shape.TensorShape(input_shape).as_list()
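The hunks above all add `self.built = True` at the end of `build()`. A minimal custom-layer sketch (assuming the Keras 2 `Layer` API) shows the convention these changes enforce: build() must mark the layer as built, either directly or via the parent class, otherwise the layer is rebuilt on every call.

from keras.layers import Layer

class Scale(Layer):
    def build(self, input_shape):
        self.alpha = self.add_weight(name='alpha', shape=(1,),
                                     initializer='ones', trainable=True)
        super(Scale, self).build(input_shape)   # sets self.built = True

    def call(self, inputs):
        return self.alpha * inputs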

View File

@ -844,7 +844,7 @@ def convolution(inputs,
variable would be created and added the activations. Finally, if variable would be created and added the activations. Finally, if
`activation_fn` is not `None`, it is applied to the activations as well. `activation_fn` is not `None`, it is applied to the activations as well.
Performs a'trous convolution with input stride/dilation rate equal to `rate` Performs atrous convolution with input stride/dilation rate equal to `rate`
if a value > 1 for any dimension of `rate` is specified. In this case if a value > 1 for any dimension of `rate` is specified. In this case
`stride` values != 1 are not supported. `stride` values != 1 are not supported.
@ -870,7 +870,7 @@ def convolution(inputs,
"NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
For N=3, the valid values are "NDHWC" (default) and "NCDHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW".
rate: A sequence of N positive integers specifying the dilation rate to use rate: A sequence of N positive integers specifying the dilation rate to use
for a'trous convolution. Can be a single integer to specify the same for atrous convolution. Can be a single integer to specify the same
value for all spatial dimensions. Specifying any `rate` value != 1 is value for all spatial dimensions. Specifying any `rate` value != 1 is
incompatible with specifying any `stride` value != 1. incompatible with specifying any `stride` value != 1.
activation_fn: Activation function. The default value is a ReLU function. activation_fn: Activation function. The default value is a ReLU function.
@ -1865,7 +1865,7 @@ def separable_convolution2d(
depthwise convolution stride. Can be an int if both strides are the same. depthwise convolution stride. Can be an int if both strides are the same.
padding: One of 'VALID' or 'SAME'. padding: One of 'VALID' or 'SAME'.
rate: A list of length 2: [rate_height, rate_width], specifying the dilation rate: A list of length 2: [rate_height, rate_width], specifying the dilation
rates for a'trous convolution. Can be an int if both rates are the same. rates for atrous convolution. Can be an int if both rates are the same.
If any value is larger than one, then both stride values need to be one. If any value is larger than one, then both stride values need to be one.
activation_fn: Activation function. The default value is a ReLU function. activation_fn: Activation function. The default value is a ReLU function.
Explicitly set it to None to skip it and maintain a linear activation. Explicitly set it to None to skip it and maintain a linear activation.
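As a hedged illustration of the atrous (dilated) convolution the docstring describes, the same effect can be obtained with `tf.nn.convolution` directly; shapes here are illustrative, and with any `dilation_rate` value > 1 the strides must stay at 1.

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])   # NHWC
filters = tf.get_variable('w', [3, 3, 3, 16])            # height, width, in, out
atrous = tf.nn.convolution(inputs, filters, padding='SAME',
                           strides=[1, 1], dilation_rate=[2, 2])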

View File

@ -966,7 +966,8 @@ class BaseEstimator(
saver.Saver( saver.Saver(
sharded=True, sharded=True,
max_to_keep=self._config.keep_checkpoint_max, max_to_keep=self._config.keep_checkpoint_max,
defer_build=True)) defer_build=True,
save_relative_paths=True))
chief_hooks = [] chief_hooks = []
if (self._config.save_checkpoints_secs or if (self._config.save_checkpoints_secs or
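A hedged sketch of what `save_relative_paths=True` buys at the plain `tf.train.Saver` level (paths are illustrative): the checkpoint state file records paths relative to its directory, so the whole model directory can be renamed or copied and still restored from its new location.

import tensorflow as tf

v = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())
saver = tf.train.Saver(save_relative_paths=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, '/tmp/model_dir1/model.ckpt', global_step=5)
# After renaming /tmp/model_dir1 to /tmp/model_dir2,
# tf.train.latest_checkpoint('/tmp/model_dir2') still resolves.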

View File

@ -28,6 +28,8 @@ import numpy as np
import six import six
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from google.protobuf import text_format
from tensorflow.contrib import learn from tensorflow.contrib import learn
from tensorflow.contrib import lookup from tensorflow.contrib import lookup
from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.framework.python.ops import variables
@ -50,6 +52,7 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -61,6 +64,7 @@ from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import input as input_lib from tensorflow.python.training import input as input_lib
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
@ -674,6 +678,38 @@ class EstimatorTest(test.TestCase):
metrics={'MSE': metric_ops.streaming_mean_squared_error}) metrics={'MSE': metric_ops.streaming_mean_squared_error})
self.assertLess(scores3['MSE'], scores['MSE']) self.assertLess(scores3['MSE'], scores['MSE'])
def test_checkpoint_contains_relative_paths(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
model_dir=tmpdir,
model_fn=linear_model_fn_with_model_fn_ops)
est.fit(input_fn=boston_input_fn, steps=5)
checkpoint_file_content = file_io.read_file_to_string(
os.path.join(tmpdir, 'checkpoint'))
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
self.assertAllEqual(
['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
model_dir1 = os.path.join(tmpdir, 'model_dir1')
est1 = estimator.Estimator(
model_dir=model_dir1,
model_fn=linear_model_fn_with_model_fn_ops)
est1.fit(input_fn=boston_input_fn, steps=5)
model_dir2 = os.path.join(tmpdir, 'model_dir2')
os.renames(model_dir1, model_dir2)
est2 = estimator.Estimator(
model_dir=model_dir2,
model_fn=linear_model_fn_with_model_fn_ops)
self.assertEqual(5, est2.get_variable_value('global_step'))
est2.fit(input_fn=boston_input_fn, steps=5)
self.assertEqual(10, est2.get_variable_value('global_step'))
def testEstimatorParams(self): def testEstimatorParams(self):
boston = base.load_boston() boston = base.load_boston()
est = estimator.SKCompat( est = estimator.SKCompat(

View File

@ -379,7 +379,12 @@ def multi_label_head(n_classes,
loss_fn=None): loss_fn=None):
"""Creates a Head for multi label classification. """Creates a Head for multi label classification.
The Head uses sigmoid cross entropy loss. Multi-label classification handles the case where each example may have zero
or more associated labels drawn from a discrete set. This is distinct from
`multi_class_head` which has exactly one label from a discrete set.
This head by default uses sigmoid cross entropy loss, which expects as input
a multi-hot tensor of shape `(batch_size, num_classes)`.
Args: Args:
n_classes: Integer, number of classes, must be >= 2 n_classes: Integer, number of classes, must be >= 2
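A small sketch of the loss the docstring describes (not the head's internal implementation): labels are multi-hot with shape `(batch_size, n_classes)` and sigmoid cross entropy is applied per class; values are illustrative.

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5]])   # batch_size=1, n_classes=3
labels = tf.constant([[1.0, 0.0, 1.0]])    # the example carries classes 0 and 2
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))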

View File

@ -28,6 +28,7 @@ import six
from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as core_run_config from tensorflow.python.estimator import run_config as core_run_config
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
@ -260,10 +261,12 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
the feature. the feature.
evaluation_master: the master on which to perform evaluation. evaluation_master: the master on which to perform evaluation.
model_dir: directory where model parameters, graph etc are saved. If model_dir: directory where model parameters, graph etc are saved. If
`None`, see `Estimator` about where the model will be saved. `None`, will use `model_dir` property in `TF_CONFIG` environment
variable. If both are set, they must have the same value. If both are `None`, see
`Estimator` about where the model will be saved.
session_config: a ConfigProto used to set session parameters, or None. session_config: a ConfigProto used to set session parameters, or None.
Note - using this argument, it is easy to provide settings which break Note - using this argument, it is easy to provide settings which break
otherwise perfectly good models. Use with care. otherwise perfectly good models. Use with care.
""" """
super(RunConfig, self).__init__( super(RunConfig, self).__init__(
master=master, evaluation_master=evaluation_master) master=master, evaluation_master=evaluation_master)
@ -291,7 +294,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
# create Scaffold and Saver in their model_fn to set these. # create Scaffold and Saver in their model_fn to set these.
self._keep_checkpoint_max = keep_checkpoint_max self._keep_checkpoint_max = keep_checkpoint_max
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
self._model_dir = model_dir self._model_dir = _get_model_dir(model_dir)
def replace(self, **kwargs): def replace(self, **kwargs):
"""Returns a new instance of `RunConfig` replacing specified properties. """Returns a new instance of `RunConfig` replacing specified properties.
@ -434,3 +437,21 @@ def _get_master(cluster_spec, task_type, task_id):
# For backwards compatibility, we return empty string if task_type was # For backwards compatibility, we return empty string if task_type was
# not set (task_type did not previously exist). # not set (task_type did not previously exist).
return '' return ''
def _get_model_dir(model_dir):
"""Returns `model_dir` based user provided `model_dir` or `TF_CONFIG`."""
model_dir_in_tf_config = json.loads(
os.environ.get('TF_CONFIG') or '{}').get('model_dir', None)
if model_dir_in_tf_config is not None:
if model_dir is not None and model_dir_in_tf_config != model_dir:
raise ValueError(
'`model_dir` provided in RunConfig constructor, if set, '
'must have the same value as the model_dir in TF_CONFIG. '
'model_dir: {}\nTF_CONFIG["model_dir"]: {}.\n'.format(
model_dir, model_dir_in_tf_config))
logging.info('Using model_dir in TF_CONFIG: %s', model_dir_in_tf_config)
return model_dir or model_dir_in_tf_config
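A hedged usage sketch of the resolution implemented above, where `RunConfig` is the class defined in this file and the path is illustrative: with no explicit argument the value in `TF_CONFIG` wins, a matching explicit value is accepted, and a conflicting one raises.

import json
import os

os.environ['TF_CONFIG'] = json.dumps({'model_dir': '/tmp/my_model'})
config = RunConfig()                       # picks up model_dir from TF_CONFIG
assert config.model_dir == '/tmp/my_model'
RunConfig(model_dir='/tmp/my_model')       # matching explicit value is accepted
# RunConfig(model_dir='/tmp/other_dir')    # mismatch raises ValueError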

View File

@ -223,6 +223,27 @@ class RunConfigTest(test.TestCase):
config = run_config_lib.RunConfig(model_dir=TEST_DIR) config = run_config_lib.RunConfig(model_dir=TEST_DIR)
self.assertEqual(TEST_DIR, config.model_dir) self.assertEqual(TEST_DIR, config.model_dir)
def test_model_dir_in_tf_config(self):
tf_config = {"model_dir": TEST_DIR}
with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
run_config = run_config_lib.RunConfig()
self.assertEqual(TEST_DIR, run_config.model_dir)
def test_model_dir_both_in_tf_config_and_constructor(self):
tf_config = {"model_dir": TEST_DIR}
with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
run_config = run_config_lib.RunConfig(model_dir=TEST_DIR)
self.assertEqual(TEST_DIR, run_config.model_dir)
def test_model_dir_fail_if_constructor_value_mismatch_tf_config(self):
tf_config = {"model_dir": TEST_DIR}
with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
with self.assertRaisesRegexp(
ValueError,
"`model_dir` provided in RunConfig .* must have "
"the same value .* in TF_CONFIG"):
run_config_lib.RunConfig(model_dir=TEST_DIR + "/sub_dir")
def test_replace(self): def test_replace(self):
config = run_config_lib.RunConfig( config = run_config_lib.RunConfig(
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)

View File

@ -65,12 +65,15 @@ class SquareLinearOperatorCompositionTest(
# feed_dict. # feed_dict.
matrices = sess.run(matrices) matrices = sess.run(matrices)
operator = linalg.LinearOperatorComposition( operator = linalg.LinearOperatorComposition(
[linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph]) [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph],
is_square=True)
feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
else: else:
operator = linalg.LinearOperatorComposition( operator = linalg.LinearOperatorComposition(
[linalg.LinearOperatorFullMatrix(m) for m in matrices]) [linalg.LinearOperatorFullMatrix(m) for m in matrices])
feed_dict = None feed_dict = None
# Should be auto-set.
self.assertTrue(operator.is_square)
# Convert back to Tensor. Needed if use_placeholder, since then we have # Convert back to Tensor. Needed if use_placeholder, since then we have
# already evaluated each matrix to a numpy array. # already evaluated each matrix to a numpy array.

View File

@ -45,9 +45,10 @@ class SquareLinearOperatorFullMatrixTest(
# values are random and we want the same value used for both mat and # values are random and we want the same value used for both mat and
# feed_dict. # feed_dict.
matrix = matrix.eval() matrix = matrix.eval()
operator = linalg.LinearOperatorFullMatrix(matrix_ph) operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True)
feed_dict = {matrix_ph: matrix} feed_dict = {matrix_ph: matrix}
else: else:
# is_square should be auto-detected here.
operator = linalg.LinearOperatorFullMatrix(matrix) operator = linalg.LinearOperatorFullMatrix(matrix)
feed_dict = None feed_dict = None
@ -68,6 +69,8 @@ class SquareLinearOperatorFullMatrixTest(
self.assertTrue(operator.is_positive_definite) self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular) self.assertTrue(operator.is_non_singular)
self.assertFalse(operator.is_self_adjoint) self.assertFalse(operator.is_self_adjoint)
# Auto-detected.
self.assertTrue(operator.is_square)
class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
@ -104,6 +107,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
# values are random and we want the same value used for both mat and # values are random and we want the same value used for both mat and
# feed_dict. # feed_dict.
matrix = matrix.eval() matrix = matrix.eval()
# is_square is auto-set because of self_adjoint/pd.
operator = linalg.LinearOperatorFullMatrix( operator = linalg.LinearOperatorFullMatrix(
matrix_ph, is_self_adjoint=True, is_positive_definite=True) matrix_ph, is_self_adjoint=True, is_positive_definite=True)
feed_dict = {matrix_ph: matrix} feed_dict = {matrix_ph: matrix}
@ -129,7 +133,8 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
# Should be auto-set # Should be auto-set
self.assertTrue(operator.is_non_singular) self.assertTrue(operator.is_non_singular)
self.assertTrue(operator._is_spd) self.assertTrue(operator._can_use_cholesky)
self.assertTrue(operator.is_square)
class NonSquareLinearOperatorFullMatrixTest( class NonSquareLinearOperatorFullMatrixTest(
@ -157,16 +162,14 @@ class NonSquareLinearOperatorFullMatrixTest(
return operator, mat, feed_dict return operator, mat, feed_dict
def test_is_x_flags(self): def test_is_x_flags(self):
# Matrix with two positive eigenvalues. matrix = [[3., 2., 1.], [1., 1., 1.]]
matrix = [[3., 0.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix( operator = linalg.LinearOperatorFullMatrix(
matrix, matrix,
is_positive_definite=True,
is_non_singular=True,
is_self_adjoint=False) is_self_adjoint=False)
self.assertTrue(operator.is_positive_definite) self.assertEqual(operator.is_positive_definite, None)
self.assertTrue(operator.is_non_singular) self.assertEqual(operator.is_non_singular, None)
self.assertFalse(operator.is_self_adjoint) self.assertFalse(operator.is_self_adjoint)
self.assertFalse(operator.is_square)
def test_matrix_must_have_at_least_two_dims_or_raises(self): def test_matrix_must_have_at_least_two_dims_or_raises(self):
with self.assertRaisesRegexp(ValueError, "at least 2 dimensions"): with self.assertRaisesRegexp(ValueError, "at least 2 dimensions"):
View File
@ -54,6 +54,9 @@ class LinearOperatorShape(linalg.LinearOperator):
def _shape_tensor(self): def _shape_tensor(self):
return constant_op.constant(self._stored_shape, dtype=dtypes.int32) return constant_op.constant(self._stored_shape, dtype=dtypes.int32)
def _apply(self):
raise NotImplementedError("Not needed for this test.")
class LinearOperatorApplyOnly(linalg.LinearOperator): class LinearOperatorApplyOnly(linalg.LinearOperator):
"""LinearOperator that simply wraps a [batch] matrix and implements apply.""" """LinearOperator that simply wraps a [batch] matrix and implements apply."""
View File
@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import abc
import contextlib import contextlib
from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib import framework as contrib_framework
@ -25,6 +26,7 @@ from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
__all__ = ["LinearOperator"] __all__ = ["LinearOperator"]
@ -50,11 +52,9 @@ class LinearOperator(object):
#### Performance contract #### Performance contract
Subclasses should implement a method only if it can be done with a reasonable Subclasses should only implement the assert methods
performance increase over generic dense operations, either in time, parallel (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
scalability, or memory usage. For example, if the determinant can only be time.
computed using `tf.matrix_determinant(self.to_dense())`, then determinants
should not be implemented.
Class docstrings should contain an explanation of computational complexity. Class docstrings should contain an explanation of computational complexity.
Since this is a high-performance library, attention should be paid to detail, Since this is a high-performance library, attention should be paid to detail,
@ -100,7 +100,7 @@ class LinearOperator(object):
operator.shape() operator.shape()
==> [2, 4, 4] ==> [2, 4, 4]
operator.log_determinant() operator.log_abs_determinant()
==> Shape [2] Tensor ==> Shape [2] Tensor
x = ... Shape [2, 4, 5] Tensor x = ... Shape [2, 4, 5] Tensor
@ -131,6 +131,7 @@ class LinearOperator(object):
* If `is_X == None` (the default), callers should have no expectation either * If `is_X == None` (the default), callers should have no expectation either
way. way.
""" """
__metaclass__ = abc.ABCMeta
def __init__(self, def __init__(self,
dtype, dtype,
@ -167,17 +168,23 @@ class LinearOperator(object):
ValueError: If hints are set incorrectly. ValueError: If hints are set incorrectly.
""" """
# Check and auto-set flags. # Check and auto-set flags.
if is_square is False:
if is_non_singular or is_positive_definite:
raise ValueError(
"A non-singular or positive definite operator is always square.")
self._is_square_set_by_user = is_square
if is_positive_definite: if is_positive_definite:
if is_non_singular is False: if is_non_singular is False:
raise ValueError("A positive definite matrix is always non-singular.") raise ValueError("A positive definite matrix is always non-singular.")
is_non_singular = True is_non_singular = True
if is_non_singular:
if is_square is False:
raise ValueError("A non-singular matrix is always square.")
is_square = True
if is_self_adjoint:
if is_square is False:
raise ValueError("A self-adjoint matrix is always square.")
is_square = True
self._is_square_set_or_implied_by_hints = is_square
graph_parents = [] if graph_parents is None else graph_parents graph_parents = [] if graph_parents is None else graph_parents
for i, t in enumerate(graph_parents): for i, t in enumerate(graph_parents):
if t is None or not contrib_framework.is_tensor(t): if t is None or not contrib_framework.is_tensor(t):
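The hint checks added at the top of this constructor encode a small implication chain: positive definite implies non-singular, and non-singular or self-adjoint implies square, with contradictory hints rejected. A minimal Python sketch of that logic, using a hypothetical helper name rather than the library code:

def resolve_hints(is_non_singular, is_self_adjoint, is_positive_definite,
                  is_square):
  # Positive definite implies non-singular.
  if is_positive_definite:
    if is_non_singular is False:
      raise ValueError("A positive definite matrix is always non-singular.")
    is_non_singular = True
  # Non-singular or self-adjoint implies square; an explicit False is an error.
  if is_non_singular or is_self_adjoint:
    if is_square is False:
      raise ValueError("This combination of hints implies a square operator.")
    is_square = True
  return is_non_singular, is_self_adjoint, is_positive_definite, is_square

The resolved square flag is what the `is_square` property below falls back on when the static shape check is inconclusive.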
@ -239,15 +246,16 @@ class LinearOperator(object):
"""Return `True/False` depending on if this operator is square.""" """Return `True/False` depending on if this operator is square."""
# Static checks done after __init__. Why? Because domain/range dimension # Static checks done after __init__. Why? Because domain/range dimension
# sometimes requires lots of work done in the derived class after init. # sometimes requires lots of work done in the derived class after init.
static_square_check = self.domain_dimension == self.range_dimension auto_square_check = self.domain_dimension == self.range_dimension
if self._is_square_set_by_user is False and static_square_check: if self._is_square_set_or_implied_by_hints is False and auto_square_check:
raise ValueError( raise ValueError(
"User set is_square hint to False, but the operator was square.") "User set is_square hint to False, but the operator was square.")
if self._is_square_set_by_user is None: if self._is_square_set_or_implied_by_hints is None:
return static_square_check return auto_square_check
return self._is_square_set_by_user return self._is_square_set_or_implied_by_hints
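At read time the property prefers the static shape check when one is available and otherwise falls back to the stored hint; a user hint of `False` on an operator whose static shape turns out to be square is reported as an error. A usage sketch against the contrib API (values are illustrative):

import numpy as np
from tensorflow.contrib import linalg

operator = linalg.LinearOperatorFullMatrix(np.eye(2))
print(operator.is_square)  # True: domain and range dimensions match statically.

bad = linalg.LinearOperatorFullMatrix(np.eye(2), is_square=False)
# Reading bad.is_square raises ValueError, since the hint contradicts the
# statically known square shape.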
@abc.abstractmethod
def _shape(self): def _shape(self):
# Write this in derived class to enable all static shape methods. # Write this in derived class to enable all static shape methods.
raise NotImplementedError("_shape is not implemented.") raise NotImplementedError("_shape is not implemented.")
@ -265,6 +273,7 @@ class LinearOperator(object):
""" """
return self._shape() return self._shape()
@abc.abstractmethod
def _shape_tensor(self): def _shape_tensor(self):
raise NotImplementedError("_shape_tensor is not implemented.") raise NotImplementedError("_shape_tensor is not implemented.")
@ -367,8 +376,7 @@ class LinearOperator(object):
self._cached_tensor_rank_tensor = ops.convert_to_tensor( self._cached_tensor_rank_tensor = ops.convert_to_tensor(
self.tensor_rank) self.tensor_rank)
else: else:
self._cached_tensor_rank_tensor = array_ops.size( self._cached_tensor_rank_tensor = array_ops.size(self.shape_tensor())
self.shape_tensor())
return self._cached_tensor_rank_tensor return self._cached_tensor_rank_tensor
@property @property
@ -486,9 +494,10 @@ class LinearOperator(object):
"""Check that arg.dtype == self.dtype.""" """Check that arg.dtype == self.dtype."""
if arg.dtype != self.dtype: if arg.dtype != self.dtype:
raise TypeError( raise TypeError(
"Expected argument to have dtype %s. Found: %s in tensor %s" "Expected argument to have dtype %s. Found: %s in tensor %s" %
% (self.dtype, arg.dtype, arg)) (self.dtype, arg.dtype, arg))
@abc.abstractmethod
def _apply(self, x, adjoint=False, adjoint_arg=False): def _apply(self, x, adjoint=False, adjoint_arg=False):
raise NotImplementedError("_apply is not implemented.") raise NotImplementedError("_apply is not implemented.")
@ -517,7 +526,9 @@ class LinearOperator(object):
return self._apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg) return self._apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
def _determinant(self): def _determinant(self):
raise NotImplementedError("_det is not implemented.") if self._can_use_cholesky():
return math_ops.exp(self.log_abs_determinant())
return linalg_ops.matrix_determinant(self._matrix)
def determinant(self, name="det"): def determinant(self, name="det"):
"""Determinant for every batch member. """Determinant for every batch member.
@ -539,7 +550,11 @@ class LinearOperator(object):
return self._determinant() return self._determinant()
def _log_abs_determinant(self): def _log_abs_determinant(self):
raise NotImplementedError("_log_abs_det is not implemented.") if self._can_use_cholesky():
diag = array_ops.matrix_diag_part(self._get_cached_chol())
return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
abs_det = math_ops.abs(self.determinant())
return math_ops.log(abs_det)
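The Cholesky branch above relies on the identity log|det(A)| = 2 * sum(log(diag(chol(A)))) for a self-adjoint positive definite A, since det(A) = det(L)^2 when A = L L^T. A quick NumPy check of the identity with illustrative values:

import numpy as np

a = np.array([[4.0, 1.0],
              [1.0, 3.0]])                        # symmetric positive definite
chol = np.linalg.cholesky(a)                      # lower-triangular L, a = L L^T
log_abs_det = 2.0 * np.sum(np.log(np.diag(chol)))
sign, reference = np.linalg.slogdet(a)
assert np.isclose(log_abs_det, reference)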
def log_abs_determinant(self, name="log_abs_det"): def log_abs_determinant(self, name="log_abs_det"):
"""Log absolute value of determinant for every batch member. """Log absolute value of determinant for every batch member.
@ -561,13 +576,20 @@ class LinearOperator(object):
return self._log_abs_determinant() return self._log_abs_determinant()
def _solve(self, rhs, adjoint=False, adjoint_arg=False): def _solve(self, rhs, adjoint=False, adjoint_arg=False):
# Since this is an exact solve method for all rhs, this will only be if self.is_square is False:
# available for non-singular (batch) operators, in particular the operator raise NotImplementedError(
# must be square. "Solve is not yet implemented for non-square operators.")
raise NotImplementedError("_solve is not implemented.") rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
if self._can_use_cholesky():
return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
return linalg_ops.matrix_solve(
self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
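The default `_solve` follows the same split: adjoint the right-hand side first if requested, then either a Cholesky solve (self-adjoint positive definite case, where `adjoint` changes nothing) or a general dense solve. A NumPy/SciPy sketch of the equivalent computation (hypothetical helper, illustrative values):

import numpy as np
from scipy.linalg import cho_factor, cho_solve

def solve(a, rhs, spd=False, adjoint=False, adjoint_arg=False):
  rhs = np.conj(rhs).T if adjoint_arg else rhs
  if spd:
    # Self-adjoint positive definite: a == a^H, so 'adjoint' is irrelevant.
    return cho_solve(cho_factor(a), rhs)
  a = np.conj(a).T if adjoint else a
  return np.linalg.solve(a, rhs)

a = np.array([[4.0, 1.0], [1.0, 3.0]])
rhs = np.array([[1.0], [2.0]])
x = solve(a, rhs, spd=True)
assert np.allclose(a.dot(x), rhs)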
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
"""Solve `R` (batch) systems of equations exactly: `A X = rhs`. """Solve `R` (batch) systems of equations with best effort: `A X = rhs`.
The solution may not be exact, and in this case it will be close in some
sense (see class docstring for details).
Examples: Examples:
@ -689,3 +711,20 @@ class LinearOperator(object):
x = ops.convert_to_tensor(x, name="x") x = ops.convert_to_tensor(x, name="x")
self._check_input_dtype(x) self._check_input_dtype(x)
return self._add_to_tensor(x) return self._add_to_tensor(x)
def _can_use_cholesky(self):
# TODO(langmore) Add complex types when tf.cholesky can use them.
return (not self.dtype.is_complex and self.is_self_adjoint and
self.is_positive_definite)
def _get_cached_dense_matrix(self):
if not hasattr(self, "_cached_dense_matrix"):
self._cached_dense_matrix = self.to_dense()
return self._cached_dense_matrix
def _get_cached_chol(self):
if not self._can_use_cholesky():
return None
if not hasattr(self, "_cached_chol"):
self._cached_chol = linalg_ops.cholesky(self._get_cached_dense_matrix())
return self._cached_chol
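The two cache helpers are a compute-once pattern keyed on attribute presence, so the dense matrix and its Cholesky factor are built at most once per operator. A generic sketch of the same pattern with hypothetical names:

class CachesDense(object):
  def _to_dense(self):
    # Stand-in for an expensive conversion; runs only on the first call.
    return [[1.0, 0.0], [0.0, 1.0]]

  def dense(self):
    if not hasattr(self, "_cached_dense"):
      self._cached_dense = self._to_dense()
    return self._cached_dense

op = CachesDense()
first = op.dense()          # builds and caches
assert op.dense() is first  # later calls return the cached object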
View File
@ -63,7 +63,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
operator.shape operator.shape
==> [2, 2] ==> [2, 2]
operator.log_determinant() operator.log_abs_determinant()
==> scalar Tensor ==> scalar Tensor
x = ... Shape [2, 4] Tensor x = ... Shape [2, 4] Tensor
@ -96,7 +96,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
#### Matrix property hints #### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`, This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite`. for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning These have the following meaning
* If `is_X == True`, callers should expect the operator to have the * If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a property `X`. This is a promise that should be fulfilled, but is *not* a
@ -112,6 +112,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
is_non_singular=None, is_non_singular=None,
is_self_adjoint=None, is_self_adjoint=None,
is_positive_definite=None, is_positive_definite=None,
is_square=None,
name=None): name=None):
r"""Initialize a `LinearOperatorComposition`. r"""Initialize a `LinearOperatorComposition`.
@ -132,6 +133,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
self-adjoint to be positive-definite. See: self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\ https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices #Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`. Default is the individual name: A name for this `LinearOperator`. Default is the individual
operators names joined with `_o_`. operators names joined with `_o_`.
@ -177,6 +179,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
is_non_singular=is_non_singular, is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint, is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite, is_positive_definite=is_positive_definite,
is_square=is_square,
name=name) name=name)
@property @property
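Forwarding `is_square` lets a caller assert squareness of a composition whose factor shapes may only be known at run time, as in the placeholder test earlier in this change. A small usage sketch with statically known, illustrative shapes:

import numpy as np
from tensorflow.contrib import linalg

a = linalg.LinearOperatorFullMatrix(np.random.rand(3, 4))
b = linalg.LinearOperatorFullMatrix(np.random.rand(4, 3))
product = linalg.LinearOperatorComposition([a, b], is_square=True)
print(product.is_square)  # True; the hint agrees with the static 3x3 shape.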
View File
@ -52,7 +52,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
operator.shape operator.shape
==> [2, 2] ==> [2, 2]
operator.log_determinant() operator.log_abs_determinant()
==> scalar Tensor ==> scalar Tensor
x = ... Shape [2, 4] Tensor x = ... Shape [2, 4] Tensor
@ -97,7 +97,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
#### Matrix property hints #### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`, This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite`. for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning These have the following meaning
* If `is_X == True`, callers should expect the operator to have the * If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a property `X`. This is a promise that should be fulfilled, but is *not* a
@ -113,6 +113,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
is_non_singular=None, is_non_singular=None,
is_self_adjoint=None, is_self_adjoint=None,
is_positive_definite=None, is_positive_definite=None,
is_square=None,
name="LinearOperatorDiag"): name="LinearOperatorDiag"):
r"""Initialize a `LinearOperatorDiag`. r"""Initialize a `LinearOperatorDiag`.
@ -129,6 +130,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
self-adjoint to be positive-definite. See: self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\ https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices #Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`. name: A name for this `LinearOperator`.
Raises: Raises:
@ -147,12 +149,17 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
else: else:
is_self_adjoint = True is_self_adjoint = True
if is_square is False:
raise ValueError("Only square diagonal operators currently supported.")
is_square = True
super(LinearOperatorDiag, self).__init__( super(LinearOperatorDiag, self).__init__(
dtype=self._diag.dtype, dtype=self._diag.dtype,
graph_parents=[self._diag], graph_parents=[self._diag],
is_non_singular=is_non_singular, is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint, is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite, is_positive_definite=is_positive_definite,
is_square=is_square,
name=name) name=name)
def _check_diag(self, diag): def _check_diag(self, diag):
View File
@ -19,11 +19,9 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.linalg.python.ops import linear_operator from tensorflow.contrib.linalg.python.ops import linear_operator
from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
__all__ = ["LinearOperatorFullMatrix"] __all__ = ["LinearOperatorFullMatrix"]
@ -49,7 +47,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
operator.shape operator.shape
==> [2, 2] ==> [2, 2]
operator.log_determinant() operator.log_abs_determinant()
==> scalar Tensor ==> scalar Tensor
x = ... Shape [2, 4] Tensor x = ... Shape [2, 4] Tensor
@ -93,7 +91,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
#### Matrix property hints #### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`, This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite`. for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning These have the following meaning
* If `is_X == True`, callers should expect the operator to have the * If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a property `X`. This is a promise that should be fulfilled, but is *not* a
@ -109,6 +107,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
is_non_singular=None, is_non_singular=None,
is_self_adjoint=None, is_self_adjoint=None,
is_positive_definite=None, is_positive_definite=None,
is_square=None,
name="LinearOperatorFullMatrix"): name="LinearOperatorFullMatrix"):
r"""Initialize a `LinearOperatorFullMatrix`. r"""Initialize a `LinearOperatorFullMatrix`.
@ -124,6 +123,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
self-adjoint to be positive-definite. See: self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\ https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices #Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`. name: A name for this `LinearOperator`.
Raises: Raises:
@ -134,19 +134,13 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
self._matrix = ops.convert_to_tensor(matrix, name="matrix") self._matrix = ops.convert_to_tensor(matrix, name="matrix")
self._check_matrix(self._matrix) self._check_matrix(self._matrix)
# Special treatment for (real) Symmetric Positive Definite.
self._is_spd = (
(not self._matrix.dtype.is_complex)
and is_self_adjoint and is_positive_definite)
if self._is_spd:
self._chol = linalg_ops.cholesky(self._matrix)
super(LinearOperatorFullMatrix, self).__init__( super(LinearOperatorFullMatrix, self).__init__(
dtype=self._matrix.dtype, dtype=self._matrix.dtype,
graph_parents=[self._matrix], graph_parents=[self._matrix],
is_non_singular=is_non_singular, is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint, is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite, is_positive_definite=is_positive_definite,
is_square=is_square,
name=name) name=name)
def _check_matrix(self, matrix): def _check_matrix(self, matrix):
@ -177,23 +171,5 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
return math_ops.matmul( return math_ops.matmul(
self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg) self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
def _determinant(self):
if self._is_spd:
return math_ops.exp(self.log_abs_determinant())
return linalg_ops.matrix_determinant(self._matrix)
def _log_abs_determinant(self):
if self._is_spd:
diag = array_ops.matrix_diag_part(self._chol)
return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
abs_det = math_ops.abs(self.determinant())
return math_ops.log(abs_det)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
if self._is_spd:
return linalg_ops.cholesky_solve(self._chol, rhs)
return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)
def _to_dense(self): def _to_dense(self):
return self._matrix return self._matrix
View File
@ -112,7 +112,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
operator.shape operator.shape
==> [2, 2] ==> [2, 2]
operator.log_determinant() operator.log_abs_determinant()
==> 0. ==> 0.
x = ... Shape [2, 4] Tensor x = ... Shape [2, 4] Tensor
@ -180,7 +180,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
#### Matrix property hints #### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`, This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite`. for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning These have the following meaning
* If `is_X == True`, callers should expect the operator to have the * If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a property `X`. This is a promise that should be fulfilled, but is *not* a
@ -198,6 +198,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
is_non_singular=True, is_non_singular=True,
is_self_adjoint=True, is_self_adjoint=True,
is_positive_definite=True, is_positive_definite=True,
is_square=True,
assert_proper_shapes=False, assert_proper_shapes=False,
name="LinearOperatorIdentity"): name="LinearOperatorIdentity"):
r"""Initialize a `LinearOperatorIdentity`. r"""Initialize a `LinearOperatorIdentity`.
@ -224,6 +225,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
self-adjoint to be positive-definite. See: self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\ https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices #Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
assert_proper_shapes: Python `bool`. If `False`, only perform static assert_proper_shapes: Python `bool`. If `False`, only perform static
checks that initialization and method arguments have proper shape. checks that initialization and method arguments have proper shape.
If `True`, and static checks are inconclusive, add asserts to the graph. If `True`, and static checks are inconclusive, add asserts to the graph.
@ -248,12 +250,15 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
raise ValueError("An identity operator is always non-singular.") raise ValueError("An identity operator is always non-singular.")
if not is_positive_definite: if not is_positive_definite:
raise ValueError("An identity operator is always positive-definite.") raise ValueError("An identity operator is always positive-definite.")
if not is_square:
raise ValueError("An identity operator is always square.")
super(LinearOperatorIdentity, self).__init__( super(LinearOperatorIdentity, self).__init__(
dtype=dtype, dtype=dtype,
is_non_singular=is_non_singular, is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint, is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite, is_positive_definite=is_positive_definite,
is_square=is_square,
name=name) name=name)
self._num_rows = linear_operator_util.shape_tensor( self._num_rows = linear_operator_util.shape_tensor(
@ -459,7 +464,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
operator.shape operator.shape
==> [2, 2] ==> [2, 2]
operator.log_determinant() operator.log_abs_determinant()
==> 2 * Log[3] ==> 2 * Log[3]
x = ... Shape [2, 4] Tensor x = ... Shape [2, 4] Tensor
@ -510,7 +515,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
#### Matrix property hints #### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`, This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite`. for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning These have the following meaning
* If `is_X == True`, callers should expect the operator to have the * If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a property `X`. This is a promise that should be fulfilled, but is *not* a
@ -527,6 +532,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
is_non_singular=None, is_non_singular=None,
is_self_adjoint=None, is_self_adjoint=None,
is_positive_definite=None, is_positive_definite=None,
is_square=True,
assert_proper_shapes=False, assert_proper_shapes=False,
name="LinearOperatorScaledIdentity"): name="LinearOperatorScaledIdentity"):
r"""Initialize a `LinearOperatorScaledIdentity`. r"""Initialize a `LinearOperatorScaledIdentity`.
@ -550,6 +556,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
self-adjoint to be positive-definite. See: self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\ https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices #Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
assert_proper_shapes: Python `bool`. If `False`, only perform static assert_proper_shapes: Python `bool`. If `False`, only perform static
checks that initialization and method arguments have proper shape. checks that initialization and method arguments have proper shape.
If `True`, and static checks are inconclusive, add asserts to the graph. If `True`, and static checks are inconclusive, add asserts to the graph.
@ -561,6 +568,9 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
""" """
self._assert_proper_shapes = assert_proper_shapes self._assert_proper_shapes = assert_proper_shapes
if not is_square:
raise ValueError("A ScaledIdentity operator is always square.")
with ops.name_scope(name, values=[multiplier, num_rows]): with ops.name_scope(name, values=[multiplier, num_rows]):
self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier") self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier")
@ -569,6 +579,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
is_non_singular=is_non_singular, is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint, is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite, is_positive_definite=is_positive_definite,
is_square=is_square,
name=name) name=name)
# Shape [B1,...Bb, 1, 1] # Shape [B1,...Bb, 1, 1]
View File
@ -53,7 +53,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
operator.shape operator.shape
==> [2, 2] ==> [2, 2]
operator.log_determinant() operator.log_abs_determinant()
==> scalar Tensor ==> scalar Tensor
x = ... Shape [2, 4] Tensor x = ... Shape [2, 4] Tensor
@ -90,7 +90,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
#### Matrix property hints #### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`, This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite`. for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning These have the following meaning
* If `is_X == True`, callers should expect the operator to have the * If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a property `X`. This is a promise that should be fulfilled, but is *not* a
@ -106,6 +106,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
is_non_singular=None, is_non_singular=None,
is_self_adjoint=None, is_self_adjoint=None,
is_positive_definite=None, is_positive_definite=None,
is_square=None,
name="LinearOperatorTriL"): name="LinearOperatorTriL"):
r"""Initialize a `LinearOperatorTriL`. r"""Initialize a `LinearOperatorTriL`.
@ -126,12 +127,19 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
self-adjoint to be positive-definite. See: self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\ https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices #Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`. name: A name for this `LinearOperator`.
Raises: Raises:
TypeError: If `diag.dtype` is not an allowed type. TypeError: If `diag.dtype` is not an allowed type.
ValueError: If `is_square` is `False`.
""" """
if is_square is False:
raise ValueError(
"Only square lower triangular operators supported at this time.")
is_square = True
with ops.name_scope(name, values=[tril]): with ops.name_scope(name, values=[tril]):
self._tril = ops.convert_to_tensor(tril, name="tril") self._tril = ops.convert_to_tensor(tril, name="tril")
self._check_tril(self._tril) self._check_tril(self._tril)
@ -144,6 +152,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
is_non_singular=is_non_singular, is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint, is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite, is_positive_definite=is_positive_definite,
is_square=is_square,
name=name) name=name)
def _check_tril(self, tril): def _check_tril(self, tril):
View File
@ -2417,6 +2417,9 @@ tf_cc_test(
":test_main", ":test_main",
":testlib", ":testlib",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:function_ops",
View File
@ -1001,25 +1001,19 @@ string NewName(const Node* n, bool pretty) {
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
// We visit nodes in forward topological sort order, which is a // We visit nodes in forward topological sort order, which is a
// possible execution order of the graph. // possible execution order of the graph.
std::vector<size_t> pending(g->num_node_ids());
std::deque<const Node*> ready;
for (const Node* n : g->nodes()) {
pending[n->id()] = n->in_edges().size();
if (pending[n->id()] == 0) ready.push_back(n);
}
gtl::InlinedVector<const Edge*, 4> inputs; gtl::InlinedVector<const Edge*, 4> inputs;
gdef->Clear(); gdef->Clear();
gdef->mutable_versions()->CopyFrom(g->versions()); gdef->mutable_versions()->CopyFrom(g->versions());
while (!ready.empty()) {
const Node* n = ready.front(); std::vector<Node*> start_nodes;
ready.pop_front(); for (Node* n : g->nodes()) {
for (const Edge* e : n->out_edges()) { if (n->out_edges().empty()) {
const Node* next = e->dst(); start_nodes.push_back(n);
if (--pending[next->id()] == 0) {
ready.push_back(next);
}
} }
if (!n->IsOp()) continue; }
ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
if (!n->IsOp()) return;
NodeDef* ndef = gdef->add_node(); NodeDef* ndef = gdef->add_node();
ndef->set_name(NewName(n, pretty)); ndef->set_name(NewName(n, pretty));
ndef->set_op(n->type_string()); ndef->set_op(n->type_string());
@ -1054,7 +1048,7 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
ndef->add_input(strings::StrCat(srcname, ":", e->src_output())); ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
} }
} }
} });
} }
string DebugString(const Graph* g) { string DebugString(const Graph* g) {
File diff suppressed because it is too large

View File
@ -163,7 +163,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
InferenceContext* c = iter->second.get(); InferenceContext* c = iter->second.get();
DCHECK_GE(e->dst_input(), 0); DCHECK_GE(e->dst_input(), 0);
if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) { if (node_context->MergeInput(e->dst_input(), c->output(e->src_output()))) {
*refined = true; *refined = true;
} }
@ -174,7 +174,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
e->dst_input(), c->output_handle_dtype(e->src_output()))) { e->dst_input(), c->output_handle_dtype(e->src_output()))) {
*refined = true; *refined = true;
} }
if (node_context->set_input_handle_shape( if (node_context->MergeInputHandleShape(
e->dst_input(), c->output_handle_shape(e->src_output()))) { e->dst_input(), c->output_handle_shape(e->src_output()))) {
*refined = true; *refined = true;
} }
View File
@ -400,16 +400,33 @@ void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
} }
} }
// Wrapper around protocol buffer serialization that requests deterministic
// serialization, in particular for Map fields, which serialize in a random
// order by default. Returns true on success.
template <typename T>
static bool DeterministicSerialization(const T& t, string* result) {
const int size = t.ByteSize();
*result = string(size, '\0');
::tensorflow::protobuf::io::ArrayOutputStream array_stream(&(*result)[0],
size);
::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream);
output_stream.SetSerializationDeterministic(true);
t.SerializeWithCachedSizes(&output_stream);
return !output_stream.HadError() && size == output_stream.ByteCount();
}
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
string a_str, b_str; string a_str, b_str;
a.SerializeToString(&a_str); DeterministicSerialization(a, &a_str);
b.SerializeToString(&b_str); DeterministicSerialization(b, &b_str);
// Note: it should be safe to compare proto serializations of the attr // Note: it should be safe to compare proto serializations of the attr
// values since at most one field should be set in each (indeed, it // values since at most one field should be set in each (indeed, it
// must be the same field if they are to compare equal). // must be the same field if they are to compare equal).
// Exception: there are multiple equivalent representations of // Exception: there are multiple equivalent representations of
// TensorProtos. So a return value of true implies a == b, but not the // TensorProtos. So a return value of true implies a == b, but not the
// converse. // converse.
// TODO(phawkins): this is incorrect for NameAttrList attributes that may
// contain nested AttrValue maps.
return a_str == b_str; return a_str == b_str;
} }
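The motivation is that proto map fields may serialize in an arbitrary order, so two equal messages can yield different byte strings and a byte-wise comparison would spuriously fail. A plain-Python analogy (dicts standing in for map fields, not the protobuf API):

a = {"T": "DT_FLOAT", "N": 2}
b = {"N": 2, "T": "DT_FLOAT"}   # same contents, different insertion order

def naive_serialize(attrs):
  return repr(list(attrs.items()))      # order dependent, like default maps

def deterministic_serialize(attrs):
  return repr(sorted(attrs.items()))    # fixed order, safe to compare

print(naive_serialize(a) == naive_serialize(b))                  # False here
print(deterministic_serialize(a) == deterministic_serialize(b))  # True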
View File
@ -191,16 +191,18 @@ class InferenceContext {
return s; return s;
} }
// Set the shape of the input in position idx. This requires idx to be in the // Merge the stored shape of the input in position idx with the specified
// [0, num_inputs) range. Returns true iff the stored input shape has been // shape. This requires idx to be in the [0, num_inputs) range. If the merge
// updated with a different handle. // is successful and the new shape differs from the old one, store the new
bool set_input(int idx, ShapeHandle shape) { // shape and return true. Return false otherwise.
if (!inputs_[idx].SameHandle(shape)) { bool MergeInput(int idx, ShapeHandle shape) {
inputs_[idx] = shape; ShapeHandle new_shape;
return true; if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
} else { inputs_[idx].SameHandle(new_shape)) {
return false; return false;
} }
inputs_[idx] = new_shape;
return true;
} }
ShapeHandle input(int64 idx) const { return inputs_[idx]; } ShapeHandle input(int64 idx) const { return inputs_[idx]; }
Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const; Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
@ -442,15 +444,18 @@ class InferenceContext {
// propagate that information. Output handle dtypes and shapes are ignored if // propagate that information. Output handle dtypes and shapes are ignored if
// the output tensor is not of type DT_RESOURCE. // the output tensor is not of type DT_RESOURCE.
// Set the shape corresponding to the resource in position idx. This requires // Merge the stored shape corresponding to the input handle in position idx
// idx to be in the [0, num_inputs) range. Returns true iff the stored shape // with the specified shape. This requires idx to be in the [0, num_inputs)
// has been updated with a different handle. // range. If the merge is successful and the new shape differs from the old
bool set_input_handle_shape(int idx, ShapeHandle shape) { // one, store the new shape and return true. Return false otherwise.
if (!input_handle_shape_[idx].SameHandle(shape)) { bool MergeInputHandleShape(int idx, ShapeHandle shape) {
input_handle_shape_[idx] = shape; ShapeHandle new_shape;
return true; if (!Merge(input_handle_shape_[idx], shape, &new_shape).ok() ||
input_handle_shape_[idx].SameHandle(new_shape)) {
return false;
} }
return false; input_handle_shape_[idx] = shape;
return true;
} }
// Set the type corresponding to the resource in position idx. This requires // Set the type corresponding to the resource in position idx. This requires
@ -468,15 +473,24 @@ class InferenceContext {
return input_handle_dtype_[idx]; return input_handle_dtype_[idx];
} }
// Set the shape corresponding to the resource in position idx. This requires // Merge the stored shape corresponding to the output handle in position idx
// idx to be in the [0, num_outputs) range. // with the specified shape. This requires idx to be in the [0, num_outputs)
// Returns true iff the stored shape has been updated with a different handle. // range. If the merge is successful and the new shape differs from the old
bool set_output_handle_shape(int idx, ShapeHandle shape) { // one, store the new shape and return true. Return false otherwise.
if (!output_handle_shape_[idx].SameHandle(shape)) {
output_handle_shape_[idx] = shape; bool MergeOutputHandleShape(int idx, ShapeHandle shape) {
return true; ShapeHandle new_shape;
if (!Merge(output_handle_shape_[idx], shape, &new_shape).ok() ||
output_handle_shape_[idx].SameHandle(new_shape)) {
return false;
} }
return false; output_handle_shape_[idx] = shape;
return true;
}
// Overwrite the shape corresponding to the output handle in position idx with
// the specified shape.
void set_output_handle_shape(int idx, ShapeHandle shape) {
output_handle_shape_[idx] = shape;
} }
// Set the type corresponding to the resource in position idx. This requires // Set the type corresponding to the resource in position idx. This requires
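The shift from set to merge means an incoming shape can only refine what is already stored: unknown dimensions pick up known values from either side, conflicting known dimensions are an error, and the caller learns whether anything actually changed. A simplified sketch of that merge rule (not the InferenceContext API; `None` marks an unknown dimension):

def merge_shapes(old, new):
  if len(old) != len(new):
    raise ValueError("incompatible ranks")
  merged = []
  for d_old, d_new in zip(old, new):
    if d_old is not None and d_new is not None and d_old != d_new:
      raise ValueError("incompatible dimensions")
    merged.append(d_old if d_old is not None else d_new)
  return merged

stored = [None, 7]
incoming = [3, None]
merged = merge_shapes(stored, incoming)   # [3, 7]
changed = merged != stored                # True: the refiner would report progress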
View File
@ -23,8 +23,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
void DFS(const Graph& g, std::function<void(Node*)> enter, void DFS(const Graph& g, const std::function<void(Node*)>& enter,
std::function<void(Node*)> leave) { const std::function<void(Node*)>& leave) {
// Stack of work to do. // Stack of work to do.
struct Work { struct Work {
Node* node; Node* node;
@ -61,15 +61,23 @@ void DFS(const Graph& g, std::function<void(Node*)> enter,
} }
} }
void ReverseDFS(const Graph& g, std::function<void(Node*)> enter, void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
std::function<void(Node*)> leave) { const std::function<void(Node*)>& leave) {
ReverseDFSFrom(g, {g.sink_node()}, enter, leave);
}
void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave) {
// Stack of work to do. // Stack of work to do.
struct Work { struct Work {
Node* node; Node* node;
bool leave; // Are we entering or leaving n? bool leave; // Are we entering or leaving n?
}; };
std::vector<Work> stack; std::vector<Work> stack(start.size());
stack.push_back(Work{g.sink_node(), false}); for (int i = 0; i < start.size(); ++i) {
stack[i] = Work{start[i], false};
}
std::vector<bool> visited(g.num_node_ids(), false); std::vector<bool> visited(g.num_node_ids(), false);
while (!stack.empty()) { while (!stack.empty()) {
View File
@ -21,20 +21,28 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow { namespace tensorflow {
// Perform a depth-first-search on g starting at the source node. // Perform a depth-first-search on g starting at the source node.
// If enter is not empty, calls enter(n) before visiting any children of n. // If enter is not empty, calls enter(n) before visiting any children of n.
// If leave is not empty, calls leave(n) after visiting all children of n. // If leave is not empty, calls leave(n) after visiting all children of n.
extern void DFS(const Graph& g, std::function<void(Node*)> enter, extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
std::function<void(Node*)> leave); const std::function<void(Node*)>& leave);
// Perform a reverse depth-first-search on g starting at the sink node. // Perform a reverse depth-first-search on g starting at the sink node.
// If enter is not empty, calls enter(n) before visiting any parents of n. // If enter is not empty, calls enter(n) before visiting any parents of n.
// If leave is not empty, calls leave(n) after visiting all parents of n. // If leave is not empty, calls leave(n) after visiting all parents of n.
extern void ReverseDFS(const Graph& g, std::function<void(Node*)> enter, extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
std::function<void(Node*)> leave); const std::function<void(Node*)>& leave);
// Perform a reverse depth-first-search on g starting at the 'start' nodes.
// If enter is not empty, calls enter(n) before visiting any parents of n.
// If leave is not empty, calls leave(n) after visiting all parents of n.
extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave);
// Stores in *order the post-order numbering of all nodes // Stores in *order the post-order numbering of all nodes
// in graph found via a depth first search starting at the source node. // in graph found via a depth first search starting at the source node.
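A compact Python sketch of the new entry point, with the graph reduced to a dict from each node to the list of its inputs and optional enter/leave callbacks. Emitting nodes from the leave callback while starting at nodes with no consumers yields an order in which every node follows all of its inputs, which is how the rewritten ToGraphDef uses it (sketch only, assuming an acyclic graph):

def reverse_dfs_from(inputs_of, start, enter=None, leave=None):
  visited = set()
  stack = [(node, False) for node in start]
  while stack:
    node, leaving = stack.pop()
    if leaving:
      if leave:
        leave(node)
      continue
    if node in visited:
      continue
    visited.add(node)
    if enter:
      enter(node)
    stack.append((node, True))            # leave(node) fires after its parents
    for parent in inputs_of.get(node, []):
      if parent not in visited:
        stack.append((parent, False))

# a -> b -> d and a -> c -> d; traverse backwards from the terminal node d.
inputs_of = {"d": ["b", "c"], "b": ["a"], "c": ["a"]}
order = []
reverse_dfs_from(inputs_of, ["d"], leave=order.append)
print(order)  # e.g. ['a', 'c', 'b', 'd']; inputs always precede consumers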
View File
@ -90,6 +90,23 @@ cc_test(
], ],
) )
cc_library(
name = "robust_stats",
srcs = ["robust_stats.cc"],
hdrs = ["robust_stats.h"],
visibility = ["//visibility:public"],
)
cc_test(
name = "robust_stats_test",
srcs = ["robust_stats_test.cc"],
deps = [
":robust_stats",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library( cc_library(
name = "utils", name = "utils",
srcs = ["utils.cc"], srcs = ["utils.cc"],
@ -116,3 +133,37 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )
cc_library(
name = "virtual_scheduler",
srcs = ["virtual_scheduler.cc"],
hdrs = ["virtual_scheduler.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:cost_estimator",
],
)
cc_library(
name = "measuring_cost_estimator",
srcs = ["measuring_cost_estimator.cc"],
hdrs = ["measuring_cost_estimator.h"],
visibility = ["//visibility:public"],
deps = [
":robust_stats",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:cost_estimator",
"//tensorflow/core/kernels:ops_util",
],
)
View File
@ -84,8 +84,8 @@ Status GraphProperties::InferStatically() {
} }
} }
} }
if (qctx->set_output_handle_dtype(0, queue_type) || if (qctx->set_output_handle_dtype(0, queue_type) |
qctx->set_output_handle_shape(0, queue_shp)) { qctx->MergeOutputHandleShape(0, queue_shp)) {
new_shapes.push(qnode); new_shapes.push(qnode);
} }
} }
View File
@ -177,10 +177,14 @@ TEST_F(GraphPropertiesTest, Queues) {
auto dequeue2 = auto dequeue2 =
ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT}); ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
// Create a queue that feeds itself.
auto q3 = auto q3 =
ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT}); ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
auto dequeue3 = auto dequeue3 =
ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT}); ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2});
auto enqueue3 =
ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output});
auto q4 = auto q4 =
ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT}); ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
@ -227,6 +231,229 @@ TEST_F(GraphPropertiesTest, Queues) {
EXPECT_EQ(7, prop4.shape().dim(1).size()); EXPECT_EQ(7, prop4.shape().dim(1).size());
} }
TEST_F(GraphPropertiesTest, Loops) {
// Test graph produced in python using:
/*
with tf.Graph().as_default():
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
with open('/tmp/graph.txt', 'w') as f:
f.write(str(tf.get_default_graph().as_graph_def()))
*/
const string gdef_ascii = R"EOF(
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "while/Enter"
op: "Enter"
input: "Const"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "frame_name"
value {
s: "while/while/"
}
}
attr {
key: "is_constant"
value {
b: false
}
}
attr {
key: "parallel_iterations"
value {
i: 10
}
}
}
node {
name: "while/Merge"
op: "Merge"
input: "while/Enter"
input: "while/NextIteration"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Less/y"
op: "Const"
input: "^while/Merge"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
}
node {
name: "while/Less"
op: "Less"
input: "while/Merge"
input: "while/Less/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/LoopCond"
op: "LoopCond"
input: "while/Less"
}
node {
name: "while/Switch"
op: "Switch"
input: "while/Merge"
input: "while/LoopCond"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
s: "loc:@while/Merge"
}
}
}
}
node {
name: "while/Identity"
op: "Identity"
input: "while/Switch:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Add/y"
op: "Const"
input: "^while/Identity"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
}
node {
name: "while/Add"
op: "Add"
input: "while/Identity"
input: "while/Add/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/NextIteration"
op: "NextIteration"
input: "while/Add"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Exit"
op: "Exit"
input: "while/Switch"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 11
}
)EOF";
GrapplerItem item;
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically());
const auto props = properties.GetOutputProperties("while/Exit");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_INT32, prop.dtype());
EXPECT_TRUE(prop.shape().unknown_rank());
}
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow
View File
@ -0,0 +1,133 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include <limits>
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/robust_stats.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace grappler {
MeasuringCostEstimator::MeasuringCostEstimator(Cluster* cluster,
int measurement_steps,
int measurement_threads)
: measurement_steps_(measurement_steps),
measurement_threads_(measurement_threads) {
CHECK_GE(measurement_steps, 1);
if (measurement_threads > 0) {
thread_pool_.reset(new thread::ThreadPool(
Env::Default(), SanitizeThreadSuffix("measurements"),
measurement_threads));
}
cluster_ = cluster;
}
Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) {
feed_ = item.feed;
fetch_ = item.fetch;
return cluster_->Initialize(item);
}
Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
CostGraphDef* cost_graph,
Costs* costs) const {
std::vector<double> times(measurement_steps_);
BlockingCounter barrier(measurement_steps_);
mutex status_mu;
Status status;
auto measurement_fn = [&](const int step) {
const Costs::MicroSeconds start = Env::Default()->NowMicros();
RunMetadata metadata;
const Status local_status =
cluster_->Run(optimized_graph, feed_, fetch_, &metadata);
{
mutex_lock lock(status_mu);
status.Update(local_status);
}
if (step < 0) {
// Discard the first iteration as it triggers the warmup, and therefore
// takes much longer than a normal step.
return;
}
if (!local_status.ok()) {
// Discard the data if the run wasn't successful.
barrier.DecrementCount();
return;
}
const Costs::MicroSeconds finish = Env::Default()->NowMicros();
const double time = (finish - start).count() * 1e3;
times[step] = time;
if (cost_graph && (step + 1 == measurement_steps_)) {
metadata.mutable_cost_graph()->Swap(cost_graph);
}
barrier.DecrementCount();
};
// Initialize the computation and warm up TensorFlow.
measurement_fn(-1);
if (!status.ok()) {
LOG(ERROR) << "Failed to run start measurements: "
<< status.error_message();
costs->execution_time = Costs::Duration::max();
return status;
}
// Run "measurement_steps_" and measure the time.
if (measurement_threads_ > 0) {
for (int i = 0; i < measurement_steps_; ++i) {
thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); });
}
barrier.Wait();
} else {
for (int i = 0; i < measurement_steps_ && status.ok(); ++i) {
measurement_fn(i);
}
}
if (!status.ok()) {
LOG(ERROR) << "Failed to measure graph performance: "
<< status.error_message();
costs->execution_time = Costs::Duration::max();
costs->max_execution_time = Costs::Duration::max();
costs->min_execution_time = 0;
return status;
}
// Compute the average time of the measurement steps. Use Huber statistics
// to filter out outliers.
RobustStats stats(times);
costs->execution_time = Costs::Duration(stats.mean());
costs->max_execution_time = Costs::Duration(stats.hi());
costs->min_execution_time = Costs::Duration(stats.lo());
return Status::OK();
}
} // end namespace grappler
} // end namespace tensorflow
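Stripped of the threading and error handling, the measurement strategy is a discarded warmup run followed by a fixed number of timed runs whose durations are then summarized robustly. A serial Python sketch with a hypothetical run_once standing in for Cluster::Run:

import time

def measure(run_once, measurement_steps):
  run_once()                                  # warmup; its time is discarded
  times_ms = []
  for _ in range(measurement_steps):
    start = time.time()
    run_once()
    times_ms.append((time.time() - start) * 1e3)
  return times_ms                             # fed to a robust mean, see below

samples = measure(lambda: sum(range(100000)), measurement_steps=10)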
View File
@ -0,0 +1,76 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_
#define TENSORFLOW_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class CostGraphDef;
class GraphDef;
} // namespace tensorflow
namespace tensorflow {
namespace grappler {
class Cluster;
struct GrapplerItem;
// Estimate the cost of running a Grappler item by actually running the
// corresponding TensorFlow graph on the specified cluster and measuring the
// runtimes.
class MeasuringCostEstimator : public CostEstimator {
public:
// Run the model for measurement_steps to measure its average cost.
// When measurement_threads is greater than 0, use a threadpool of as many
// threads to run the measurements; otherwise, run them serially. Does not
// take ownership of cluster.
explicit MeasuringCostEstimator(Cluster* cluster, int measurement_steps,
int measurement_threads);
~MeasuringCostEstimator() override {}
// Initializes the estimator for the specified grappler item.
// Returns the status of initializing the underlying cluster.
Status Initialize(const GrapplerItem& item) override;
// Runs the optimized version of the graph on the cluster, measures
// the runtimes of each operation, and annotates the CostGraphDef
// with the corresponding measurements.
// Returns the average latency for the whole graph.
Status PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph,
Costs* overall_cost) const override;
private:
Cluster* cluster_; // Not owned.
int measurement_steps_;
int measurement_threads_;
std::vector<std::pair<string, Tensor>> feed_;
std::vector<string> fetch_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_
View File
@ -0,0 +1,152 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/robust_stats.h"
#include <algorithm>
#include <cmath>
namespace tensorflow {
namespace grappler {
// Given a sorted vector of values, calculate the median.
// Returns 0 for an empty vector. Does not verify sortedness.
static double SortedMedian(const std::vector<double> &values) {
const int n = values.size();
if (n == 0) return 0.0;
if (n & 1) {
return values[n / 2];
} else {
return (values[n / 2] + values[n / 2 - 1]) / 2.0;
}
}
// Given a vector of values (sorted or not), calculate the median.
static double Median(std::vector<double> &&values) {
const size_t n = values.size();
if (n == 0) return 0;
const auto middle = values.begin() + (n / 2);
// Put the middle value in its place.
std::nth_element(values.begin(), middle, values.end());
if (n & 1) {
return *middle;
}
// Return the average of the two middle elements. The largest element below
// *middle lies between begin and middle as a post-condition of
// nth_element.
const auto lower_middle = std::max_element(values.begin(), middle);
// Prevent overflow. We know that '*lower_middle <= *middle'.
// If both are on opposite sides of zero, the sum won't overflow; otherwise
// the difference won't overflow.
if (*lower_middle <= 0 && *middle >= 0) {
return (*lower_middle + *middle) / 2;
}
return *lower_middle + (*middle - *lower_middle) / 2;
}
// Given a set of values, calculates the scaled Median Absolute Deviation (a
// robust approximation to the standard deviation). This is calculated as the
// median of the absolute deviations from the median, scaled by 1.4826. Its
// advantage over the standard deviation is that it is not (as) affected by
// outlier values. Returns a pair<median, mad>.
static std::pair<double, double> ScaledMedianAbsoluteDeviation(
const std::vector<double> &sorted_values) {
double median = SortedMedian(sorted_values);
// Next, we calculate the absolute deviations from the median,
// find the median of the resulting data, and scale by 1.4826.
std::vector<double> deviations;
deviations.reserve(sorted_values.size());
for (double d : sorted_values) {
deviations.push_back(std::abs(d - median));
}
double mad = Median(std::move(deviations)) * 1.4826;
return std::pair<double, double>(median, mad);
}
RobustStats::RobustStats(const std::vector<double> &values)
: RobustStats(std::vector<double>(values)) {}
RobustStats::RobustStats(std::vector<double> &&values) {
std::sort(values.begin(), values.end());
lo_ = values[0];
hi_ = values.back();
HuberMAD(values);
}
// Computes an updated mean using Huber's weighting function (values beyond
// the margin are weighted by margin / abs(value - mean)).
double UpdateHuberMean(const std::vector<double> &sorted_values, double mean,
double margin) {
int num_within = 0;
double sum = 0.0;
for (double d : sorted_values) {
if (d < mean - margin) {
sum -= margin;
} else if (d > mean + margin) {
sum += margin;
} else {
sum += d;
++num_within;
}
}
// It is possible, for a set with an interquartile distance of 0, i.e., with
// more than half of the values at the median, to encounter the case where
// the Huber mean drifts slightly off the median and there are no values
// within the margin. In that case, just return the old mean, and the caller
// will quit.
if (num_within > 0) {
return sum / num_within;
} else {
return mean;
}
}
// Given a list of values, this approximates the stddev using the MAD and then
// uses it to compute a Huber robust mean (sandwich mean). A margin of
// c*stddev is defined around the current mean, and values are weighted by
// margin / abs(value - mean) if outside the margin, or 1 if inside. This
// computes the mean iteratively, because the margin shifts a bit each time
// the mean changes. It typically settles very quickly, but it's possible for
// it to be unstable. We limit it to 10 iterations.
//
void RobustStats::HuberMAD(const std::vector<double> &sorted_values) {
const std::pair<double, double> median_mad =
ScaledMedianAbsoluteDeviation(sorted_values);
mean_ = median_mad.first;
stddev_ = median_mad.second;
// c = 1.345 is the commonly used cutoff with 95% efficiency at the normal.
// We're using c = 1.5 to be a little more conservative, and because that's
// the default in S-plus.
// TODO(dehnert): Specialize Stats for integral types so we don't implement
// methods that don't make sense.
const double c = 1.5;
const double margin = c * stddev_;
// Iterate 10 times, or until the Huber mean stabilizes.
// If the margin is zero, we don't want mean to drift from the median.
if (margin > 0.0) {
for (int k = 0; k < 10; ++k) {
double old_mean = mean_;
mean_ = UpdateHuberMean(sorted_values, mean_, margin);
if (mean_ == old_mean) break;
}
}
}
} // namespace grappler
} // namespace tensorflow
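
To make the recipe in the comments above concrete, here is a standalone sketch (not part of this file) of the same steps on toy data: sorted median, scaled MAD, and one Huber-style mean update with margin = 1.5 * MAD. All values and names are illustrative only.

```c++
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> v = {4.0, 5.0, 5.0, 6.0, 7.0, 100.0};  // one outlier
  std::sort(v.begin(), v.end());
  const double median = (v[2] + v[3]) / 2.0;  // 5.5 for the six sorted values

  // Median absolute deviation, scaled by 1.4826 as a robust stddev estimate.
  std::vector<double> dev;
  for (double d : v) dev.push_back(std::abs(d - median));
  std::sort(dev.begin(), dev.end());
  const double mad = ((dev[2] + dev[3]) / 2.0) * 1.4826;

  // One Huber update: values outside [mean - margin, mean + margin] only
  // contribute +/- margin, so the single outlier barely moves the mean.
  const double margin = 1.5 * mad;
  double mean = median, sum = 0.0;
  int num_within = 0;
  for (double d : v) {
    if (d < mean - margin) sum -= margin;
    else if (d > mean + margin) sum += margin;
    else { sum += d; ++num_within; }
  }
  if (num_within > 0) mean = sum / num_within;
  std::cout << "median=" << median << " mad=" << mad
            << " huber_mean=" << mean << "\n";
}
```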

View File

@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_COSTS_ROBUST_STATS_H_
#define TENSORFLOW_GRAPPLER_COSTS_ROBUST_STATS_H_
#include <vector>
namespace tensorflow {
namespace grappler {
class RobustStats {
public:
RobustStats(const std::vector<double>& values);
RobustStats(std::vector<double>&& values);
double lo() const { return lo_; }
double hi() const { return hi_; }
double mean() const { return mean_; }
private:
void HuberMAD(const std::vector<double>& values);
double lo_;
double hi_;
double mean_;
double stddev_;
};
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_COSTS_ROBUST_STATS_H_
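
As a quick illustration of the interface above, a hypothetical call site (not part of this change) that summarizes a set of per-step latency samples might look like this; the function name and the use of VLOG are assumptions for the sketch:

```c++
#include <vector>
#include "tensorflow/core/grappler/costs/robust_stats.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper: summarize per-step latency measurements.
void SummarizeLatencies(const std::vector<double>& latencies_us) {
  tensorflow::grappler::RobustStats stats(latencies_us);
  // lo()/hi() are the min and max samples; mean() is the Huber robust mean,
  // so a few outlier steps barely affect it.
  VLOG(1) << "latency us: min=" << stats.lo() << " max=" << stats.hi()
          << " robust mean=" << stats.mean();
}
```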

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/robust_stats.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
class RobustStatsTest : public ::testing::Test {
public:
void SetUp() override {
for (double d = 1.0; d <= 5.0; d += 1.0) {
values1_.push_back(5.0 - d);
values1_.push_back(5.0 + d);
values2_.push_back(25.0 - 2 * d);
values2_.push_back(25.0 + 2 * d);
values3_.push_back(-3.0 - d);
values3_.push_back(-3.0 + d);
}
values1_.push_back(5.0); // Odd # elements, mean is 5.0
values3_.push_back(197.0);
values3_.push_back(-203.0); // Even # elements, mean is -3.0
}
std::vector<double> values1_;
std::vector<double> values2_;
std::vector<double> values3_;
};
TEST_F(RobustStatsTest, Simple) {
RobustStats s1(values1_);
EXPECT_EQ(5.0, s1.mean());
EXPECT_EQ(0.0, s1.lo());
EXPECT_EQ(10.0, s1.hi());
RobustStats s2(values2_);
EXPECT_EQ(25.0, s2.mean());
EXPECT_EQ(15.0, s2.lo());
EXPECT_EQ(35.0, s2.hi());
RobustStats s3(values3_);
EXPECT_EQ(-3.0, s3.mean());
EXPECT_EQ(-203.0, s3.lo());
EXPECT_EQ(197.0, s3.hi());
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,215 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
namespace {
Costs CombineCosts(const Costs& left, const Costs& right) {
CHECK_NE(left.max_memory, kMemoryUnknown);
CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);
Costs result = left;
result.execution_time += right.execution_time;
if (right.max_memory != kMemoryUnknown) {
result.max_memory += right.max_memory;
}
if (right.max_per_op_buffers != kMemoryUnknown) {
result.max_per_op_buffers =
std::max(left.max_per_op_buffers, right.max_per_op_buffers);
}
if (right.max_per_op_streaming != kMemoryUnknown) {
result.max_per_op_streaming =
std::max(left.max_per_op_streaming, right.max_per_op_streaming);
}
VLOG(2) << "costs execution_time=" << result.execution_time.count()
<< " max_memory=" << result.max_memory
<< " max_per_op_buffers=" << result.max_per_op_buffers
<< " max_per_op_streaming=" << result.max_per_op_streaming;
return result;
}
} // namespace
VirtualScheduler::VirtualScheduler(const GraphDef& graph,
const std::vector<string>& fetch_nodes)
: graph_costs_(Costs::ZeroCosts()),
// TODO(dyoon): Use a better way than FIFO.
ready_nodes_(new FIFOManager()) {
// First, get the nodes that would run to output fetch_nodes.
std::vector<const NodeDef*> nodes =
ComputeTransitiveFanin(graph, fetch_nodes);
// TODO(dyoon): this is a bit inefficient as name_to_node is already built in
// ComputeTransitiveFanin().
std::unordered_map<string, const NodeDef*> name_to_node;
for (const auto& node : graph.node()) {
name_to_node[node.name()] = &node;
}
// Build node_map.
for (const auto* node : nodes) {
auto& node_state = GetNodeStateOrCreateIt(node);
// TODO(dyoon): add SendRecv considering devices and control dependency.
for (const string& input : node->input()) {
const NodeDef* in = name_to_node[NodeName(input)];
CHECK(in);
node_state.inputs.push_back(in);
auto& input_node_state = GetNodeStateOrCreateIt(in);
input_node_state.outputs.push_back(node);
}
if (node->input().empty()) {
node_state.time_ready =
Costs::Duration(); // Node without input: ready at time 0.
ready_nodes_->AddNode(node);
}
}
}
const NodeDef* VirtualScheduler::GetCurrNode() const {
return ready_nodes_->GetCurrNode();
}
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
auto it = node_map_.find(node);
if (it == node_map_.end()) {
it = node_map_.emplace(node, NodeState()).first;
}
return it->second;
}
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
graph_costs_ = CombineCosts(graph_costs_, node_costs);
const auto* node = GetCurrNode();
const auto& op_name = node->op();
auto it = op_to_cost_.find(op_name);
if (it == op_to_cost_.end()) {
it = op_to_cost_.emplace(op_name, Costs::ZeroCosts()).first;
}
auto& op_cost = it->second;
op_cost = CombineCosts(op_cost, node_costs);
// Update node and device states.
auto& node_state = node_map_[node];
auto& device = device_[node->device()];
device.nodes_executed.push_back(node);
// Node is scheduled when the device is available AND all the inputs are
// ready; hence, time_scheduled is time_ready if time_ready > device curr
// time.
node_state.time_scheduled =
std::max(device.GetCurrTime(), node_state.time_ready);
// Override device curr time with the time_scheduled.
device.device_costs.execution_time = node_state.time_scheduled;
device.device_costs = CombineCosts(device.device_costs, node_costs);
auto curr_time = device.GetCurrTime();
node_state.time_finished = curr_time;
// Update device's per-op cost.
{
auto it = device.op_to_cost.find(op_name);
if (it == device.op_to_cost.end()) {
it = device.op_to_cost.emplace(op_name, Costs::ZeroCosts()).first;
}
auto& op_cost = it->second;
op_cost = CombineCosts(op_cost, node_costs);
VLOG(2) << "Op scheduled -- name: " << node->name()
<< ", op: " << node->op() << ", device: " << node->device()
<< ", ready: " << node_state.time_ready.count()
<< ", scheduled: " << node_state.time_scheduled.count()
<< ", finished: " << node_state.time_finished.count();
// Increment num_inputs_ready of the output nodes.
for (auto* output : node_state.outputs) {
auto& output_state = node_map_[output];
output_state.num_inputs_ready++;
if (output_state.num_inputs_ready == output_state.inputs.size()) {
// This output node is now ready.
output_state.time_ready = curr_time;
ready_nodes_->AddNode(output);
}
}
// Increment num_outputs_executed of the input nodes.
for (auto* input : node_state.inputs) {
auto& input_state = node_map_[input];
input_state.num_outputs_executed++;
if (input_state.num_outputs_executed == input_state.outputs.size()) {
// All the outputs are executed; no more references to this input node.
input_state.time_no_reference = curr_time;
// TODO(dyoon): collect device memory usage; note that this input node
// use device memory between time_scheduled and time_no_reference.
}
}
}
// Remove the current node; assume FIFO.
ready_nodes_->RemoveCurrNode();
return !ready_nodes_->Empty(); // True if not empty.
}
Costs VirtualScheduler::Summary() const {
// Print out basic execution summary.
VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
VLOG(1) << "Expected max per-op streaming buffers: "
<< graph_costs_.max_per_op_streaming;
VLOG(1) << "Per-op execution time:";
for (const auto& op_cost_pair : op_to_cost_) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
if (cost) { // Skip printing out zero-cost ops.
VLOG(1) << " + " << op << " : " << cost;
}
}
// Print per-device summary.
VLOG(1) << "Devices:";
Costs critical_path_costs = Costs::ZeroCosts();
for (const auto& device : device_) {
const auto& name = device.first;
const auto& state = device.second;
VLOG(1) << "Device = " << name
<< ", num_nodes = " << state.nodes_executed.size()
<< ", execution_time = " << state.GetCurrTime().count();
VLOG(1) << "Per-op execution time:";
for (const auto& op_cost_pair : state.op_to_cost) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
if (cost) { // Skip printing out zero-cost ops.
VLOG(1) << " + " << op << " : " << cost;
}
}
if (critical_path_costs.execution_time <= state.GetCurrTime()) {
critical_path_costs = state.device_costs;
}
}
VLOG(1) << "Critical path execution time: "
<< critical_path_costs.execution_time.count();
return critical_path_costs;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,116 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#include <list>
#include <memory>
#include <unordered_map>
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/grappler_item.h"
namespace tensorflow {
namespace grappler {
namespace {
struct NodeState {
std::vector<const NodeDef*> inputs;
std::vector<const NodeDef*> outputs;
int num_inputs_ready;
int num_outputs_executed;
Costs::Duration time_ready;
Costs::Duration time_scheduled;
Costs::Duration time_finished;
Costs::Duration time_no_reference;
// Node will be ready to be executed at time_ready, scheduled at
// time_scheduled, and finishes execution at time_finished.
// Between time_scheduled and time_no_reference, the node's output tensor
// needs to be on the device, using up device memory.
NodeState() {
num_inputs_ready = 0;
num_outputs_executed = 0;
time_ready = Costs::Duration::max();
time_scheduled = Costs::Duration::max();
time_finished = Costs::Duration::max();
time_no_reference = Costs::Duration::max();
}
};
struct DeviceState {
std::vector<const NodeDef*> nodes_executed;
Costs device_costs;
std::map<string, Costs> op_to_cost; // Per-op cost.
DeviceState() { device_costs = Costs::ZeroCosts(); }
Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
};
// ReadyNodeManager (abstract class):
// Keeps ready nodes and picks the best one to be scheduled.
class ReadyNodeManager {
public:
ReadyNodeManager() {}
virtual ~ReadyNodeManager() {}
virtual void AddNode(const NodeDef* node) = 0;
virtual const NodeDef* GetCurrNode() const = 0;
virtual void RemoveCurrNode() = 0;
virtual bool Empty() const = 0;
};
class FIFOManager : public ReadyNodeManager {
public:
FIFOManager() : ReadyNodeManager() {}
~FIFOManager() override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() const override { return nodes_.front(); }
void RemoveCurrNode() override { nodes_.pop_front(); }
bool Empty() const override { return nodes_.empty(); }
private:
std::list<const NodeDef*> nodes_;
};
} // namespace
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
public:
VirtualScheduler(const GraphDef& graph,
const std::vector<string>& fetch_nodes);
const NodeDef* GetCurrNode() const;
bool MarkCurrNodeExecuted(const Costs& node_costs);
Costs Summary() const;
private:
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
Costs graph_costs_; // Graph cost.
std::map<string, Costs> op_to_cost_; // Per-op cost.
std::unique_ptr<ReadyNodeManager> ready_nodes_;
std::unordered_map<const NodeDef*, NodeState> node_map_;
std::unordered_map<string, DeviceState> device_;
};
} // namespace grappler
} // end namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
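
A sketch of the scheduling loop this interface suggests (illustrative only, not part of this change). `EstimateNodeCosts` is an assumed per-node cost callback, and passing `std::string` assumes the usual `tensorflow::string` == `std::string` alias:

```c++
#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

// Hypothetical driver: feed per-node cost estimates into the scheduler in
// dependency order, then read back the critical-path summary.
tensorflow::grappler::Costs SimulateGraph(
    const tensorflow::GraphDef& graph,
    const std::vector<std::string>& fetch_nodes,
    const std::function<tensorflow::grappler::Costs(const tensorflow::NodeDef&)>&
        EstimateNodeCosts) {
  tensorflow::grappler::VirtualScheduler scheduler(graph, fetch_nodes);
  bool more_nodes = true;
  while (more_nodes) {
    const tensorflow::NodeDef* node = scheduler.GetCurrNode();
    // MarkCurrNodeExecuted() moves newly ready successors into the queue and
    // returns false once every reachable node has been simulated.
    more_nodes = scheduler.MarkCurrNodeExecuted(EstimateNodeCosts(*node));
  }
  return scheduler.Summary();  // Critical-path costs across devices.
}
```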

View File

@ -2109,7 +2109,9 @@ tf_kernel_library(
tf_kernel_library( tf_kernel_library(
name = "matrix_triangular_solve_op", name = "matrix_triangular_solve_op",
prefix = "matrix_triangular_solve_op", prefix = "matrix_triangular_solve_op",
deps = LINALG_DEPS, deps = LINALG_DEPS + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]),
) )
tf_kernel_library( tf_kernel_library(
@ -2350,6 +2352,8 @@ tf_kernel_library(
"//conditions:default": [], "//conditions:default": [],
}) + if_mkl([ }) + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
]) + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]), ]),
) )
@ -2630,6 +2634,7 @@ tf_kernel_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}) + if_cuda([ }) + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin",
]), ]),
) )

View File

@ -24,28 +24,32 @@ limitations under the License.
#if !defined(_MSC_VER) #if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll") #define UNROLL _Pragma("unroll")
#define NOUNROLL _Pragma("nounroll")
#else #else
#define UNROLL #define UNROLL
#define NOUNROLL
#endif #endif
namespace tensorflow { namespace tensorflow {
namespace { using Eigen::GpuDevice;
typedef Eigen::GpuDevice GPUDevice;
// A Cuda kernel to compute the depthwise convolution forward pass // A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format. // in NHWC format.
template <typename T> template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
const T* input, const T* filter, const T* input, const T* filter,
T* output, int num_outputs) { T* output, int num_outputs) {
const int in_rows = args.in_rows; const int in_rows = args.in_rows;
const int in_cols = args.in_cols; const int in_cols = args.in_cols;
const int in_depth = args.in_depth; const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows; const int filter_rows =
const int filter_cols = args.filter_cols; kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int depth_multiplier = args.depth_multiplier; const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride; const int stride = args.stride;
const int pad_rows = args.pad_rows; const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols; const int pad_cols = args.pad_cols;
@ -114,16 +118,20 @@ __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
// A Cuda kernel to compute the depthwise convolution forward pass // A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format. // in NCHW format.
template <typename T> template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, __global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
const T* input, const T* filter, const T* input, const T* filter,
T* output, int num_outputs) { T* output, int num_outputs) {
const int in_rows = args.in_rows; const int in_rows = args.in_rows;
const int in_cols = args.in_cols; const int in_cols = args.in_cols;
const int in_depth = args.in_depth; const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows; const int filter_rows =
const int filter_cols = args.filter_cols; kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int depth_multiplier = args.depth_multiplier; const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride; const int stride = args.stride;
const int pad_rows = args.pad_rows; const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols; const int pad_cols = args.pad_cols;
@ -235,29 +243,41 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
} }
} }
} // namespace template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format) {
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else {
assert(false);
}
}
// A simple launch pad to launch the Cuda kernel for depthwise convolution. // A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T> template <typename T>
struct DepthwiseConv2dGPULaunch { struct DepthwiseConv2dGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input, static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
const T* filter, T* output, TensorFormat data_format) { const T* filter, T* output, TensorFormat data_format) {
// In this kernel, each thread is computing the gradients from one element if (args.filter_rows == 3 && args.filter_cols == 3 &&
// in the out_backprop. Note that one element in the out_backprop can map args.depth_multiplier == 1) {
// to multiple filter elements. LaunchDepthwiseConv2dGPU<T, 3, 3, 1>(d, args, input, filter, output,
const int num_outputs = data_format);
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dGPUKernelNHWC<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dGPUKernelNCHW<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else { } else {
assert(false); LaunchDepthwiseConv2dGPU<T, -1, -1, -1>(d, args, input, filter, output,
data_format);
} }
} }
}; };
@ -266,18 +286,20 @@ template struct DepthwiseConv2dGPULaunch<float>;
template struct DepthwiseConv2dGPULaunch<double>; template struct DepthwiseConv2dGPULaunch<double>;
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input. // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
template <typename T, int KNOWN_DEPTH_MULTIPLIER> template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC( __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
const DepthwiseArgs args, const T* out_backprop, const T* filter, const DepthwiseArgs args, const T* out_backprop, const T* filter,
T* in_backprop, int num_in_backprop) { T* in_backprop, int num_in_backprop) {
const int in_rows = args.in_rows; const int in_rows = args.in_rows;
const int in_cols = args.in_cols; const int in_cols = args.in_cols;
const int in_depth = args.in_depth; const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows; const int filter_rows =
const int filter_cols = args.filter_cols; kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int depth_multiplier = KNOWN_DEPTH_MULTIPLIER == -1 const int filter_cols =
? args.depth_multiplier kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
: KNOWN_DEPTH_MULTIPLIER; const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride; const int stride = args.stride;
const int pad_rows = args.pad_rows; const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols; const int pad_cols = args.pad_cols;
@ -301,14 +323,12 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride); tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride);
const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride); const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride);
#pragma nounroll NOUNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
const int f_r = in_r + pad_rows - out_r * stride; const int f_r = in_r + pad_rows - out_r * stride;
const int temp_out_backprop_offset = const int temp_out_backprop_offset =
out_depth * out_cols * (out_r + out_rows * b); out_depth * out_cols * (out_r + out_rows * b);
const int temp_filter_offset = filter_cols * f_r; const int temp_filter_offset = filter_cols * f_r;
#pragma nounroll NOUNROLL for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
const int f_c = in_c + pad_cols - out_c * stride; const int f_c = in_c + pad_cols - out_c * stride;
int filter_offset = int filter_offset =
depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset)); depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset));
@ -328,7 +348,8 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
} }
} }
template <typename T> template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(1024) __global__ void __launch_bounds__(1024)
DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args, DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
const T* out_backprop, const T* out_backprop,
@ -337,9 +358,12 @@ __global__ void __launch_bounds__(1024)
const int in_rows = args.in_rows; const int in_rows = args.in_rows;
const int in_cols = args.in_cols; const int in_cols = args.in_cols;
const int in_depth = args.in_depth; const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows; const int filter_rows =
const int filter_cols = args.filter_cols; kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int depth_multiplier = args.depth_multiplier; const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride; const int stride = args.stride;
const int pad_rows = args.pad_rows; const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols; const int pad_cols = args.pad_cols;
@ -395,34 +419,52 @@ __global__ void __launch_bounds__(1024)
} }
} }
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
const DepthwiseArgs args,
const T* out_backprop,
const T* filter, T* in_backprop,
TensorFormat data_format) {
const int num_in_backprop =
args.batch * args.in_rows * args.in_cols * args.in_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d);
// Increase block count for when there are more warps/SM than threads/SM.
// TODO(csigg): this is pretty arbitrary and should be generalized using
// cudaOccupancyMaxPotentialBlockSize().
config.block_count *= 4;
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dBackpropInputGPUKernelNHWC<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropInputGPUKernelNCHW<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else {
assert(false);
}
}
// A simple launch pad to launch the Cuda kernel for depthwise convolution. // A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T> template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch { struct DepthwiseConv2dBackpropInputGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args, static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* filter, T* in_backprop, const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format) { TensorFormat data_format) {
const int num_in_backprop = if (args.depth_multiplier == 1) {
args.batch * args.in_rows * args.in_cols * args.in_depth; if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3, 1>(
CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d); d, args, out_backprop, filter, in_backprop, data_format);
// Increase block count for when there are more warps/SM than threads/SM.
config.block_count *= 4;
if (data_format == FORMAT_NHWC) {
if (args.depth_multiplier == 1) {
DepthwiseConv2dBackpropInputGPUKernelNHWC<T, 1>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else { } else {
DepthwiseConv2dBackpropInputGPUKernelNHWC<T, -1> LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, 1>(
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( d, args, out_backprop, filter, in_backprop, data_format);
args, out_backprop, filter, in_backprop, num_in_backprop);
} }
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropInputGPUKernelNCHW<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else { } else {
assert(false); LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, -1>(
d, args, out_backprop, filter, in_backprop, data_format);
} }
} }
}; };
@ -431,16 +473,20 @@ template struct DepthwiseConv2dBackpropInputGPULaunch<float>;
template struct DepthwiseConv2dBackpropInputGPULaunch<double>; template struct DepthwiseConv2dBackpropInputGPULaunch<double>;
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T> template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC( __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
const DepthwiseArgs args, const T* out_backprop, const T* input, const DepthwiseArgs args, const T* out_backprop, const T* input,
T* filter_backprop, int num_out_backprop) { T* filter_backprop, int num_out_backprop) {
const int in_rows = args.in_rows; const int in_rows = args.in_rows;
const int in_cols = args.in_cols; const int in_cols = args.in_cols;
const int in_depth = args.in_depth; const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows; const int filter_rows =
const int filter_cols = args.filter_cols; kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int depth_multiplier = args.depth_multiplier; const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride; const int stride = args.stride;
const int pad_rows = args.pad_rows; const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols; const int pad_cols = args.pad_cols;
@ -518,16 +564,20 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
} }
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T> template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW( __global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
const DepthwiseArgs args, const T* out_backprop, const T* input, const DepthwiseArgs args, const T* out_backprop, const T* input,
T* filter_backprop, int num_out_backprop) { T* filter_backprop, int num_out_backprop) {
const int in_rows = args.in_rows; const int in_rows = args.in_rows;
const int in_cols = args.in_cols; const int in_cols = args.in_cols;
const int in_depth = args.in_depth; const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows; const int filter_rows =
const int filter_cols = args.filter_cols; kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int depth_multiplier = args.depth_multiplier; const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride; const int stride = args.stride;
const int pad_rows = args.pad_rows; const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols; const int pad_cols = args.pad_cols;
@ -610,28 +660,44 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
} }
} }
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
const DepthwiseArgs args,
const T* out_backprop,
const T* input, T* filter_backprop,
TensorFormat data_format) {
const int num_out_backprop =
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dBackpropFilterGPUKernelNHWC<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropFilterGPUKernelNCHW<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else {
assert(false);
}
}
// A simple launch pad to launch the Cuda kernel for depthwise convolution. // A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T> template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch { struct DepthwiseConv2dBackpropFilterGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args, static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* input, T* filter_backprop, const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) { TensorFormat data_format) {
// In this kernel, each thread is computing the gradients for one element in if (args.filter_rows == 3 && args.filter_cols == 3 &&
// the out_backprop. args.depth_multiplier == 1) {
const int num_out_backprop = LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3, 1>(
args.batch * args.out_rows * args.out_cols * args.out_depth; d, args, out_backprop, input, filter_backprop, data_format);
CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dBackpropFilterGPUKernelNHWC<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropFilterGPUKernelNCHW<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else { } else {
assert(false); LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1, -1>(
d, args, out_backprop, input, filter_backprop, data_format);
} }
} }
}; };
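
The dispatch pattern introduced in this file, compile-time specialization for the common 3x3, depth-multiplier-1 case with a negative template argument falling back to the runtime value, can be summarized with a small standalone C++ sketch (illustrative only, not the CUDA kernel itself):

```c++
#include <cstdio>

// A negative template argument means "not known at compile time, read it
// from args"; otherwise the value is a compile-time constant.
struct Args { int filter_rows, filter_cols, depth_multiplier; };

template <int kKnownFilterWidth, int kKnownFilterHeight, int kKnownDepthMultiplier>
void RunKernel(const Args& args) {
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  // With the <3, 3, 1> instantiation these are compile-time constants, so the
  // compiler can fully unroll the filter loops; <-1, -1, -1> stays generic.
  std::printf("%dx%d filter, depth multiplier %d\n", filter_cols, filter_rows,
              depth_multiplier);
}

void Launch(const Args& args) {
  if (args.filter_rows == 3 && args.filter_cols == 3 &&
      args.depth_multiplier == 1) {
    RunKernel<3, 3, 1>(args);     // specialized fast path
  } else {
    RunKernel<-1, -1, -1>(args);  // generic fallback
  }
}
```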

View File

@ -220,6 +220,10 @@ And the following comand line executes the `HelloTF` program on Windows:
<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre> <pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program installed TensorFlow for Java and are ready to use the API. If the program
outputs something else, check outputs something else, check

View File

@ -3,9 +3,9 @@
## Overview ## Overview
A selection of image classification models were tested across multiple platforms A selection of image classification models were tested across multiple platforms
to create a point of reference for the TensorFlow community. The methodology, to create a point of reference for the TensorFlow community. The
links to the benchmark scripts, and commands to reproduce the results are in the [Methodology](#methodology) section details how the tests were executed and has
[Appendix](#appendix). links to the scripts used.
## Results for image classification models ## Results for image classification models
@ -120,19 +120,19 @@ VGG16 | replicated (with NCCL) | n/a
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 238 | 95.6 | 2987 | 132 1 | 142 | 238 | 95.6 | 2987 | 154
2 | 284 | 479 | 187 | 5658 | 259 2 | 284 | 479 | 187 | 5658 | 295
4 | 569 | 948 | 374 | 10509 | 511 4 | 569 | 948 | 374 | 10509 | 584
8 | 1131 | 1886 | 744 | 17822 | 959 8 | 1131 | 1886 | 744 | 17822 | 1081
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 239 | 95.5 | 2890 | 132 1 | 142 | 239 | 95.5 | 2890 | 154
2 | 278 | 468 | 187 | 4448 | 245 2 | 278 | 468 | 187 | 4448 | 284
4 | 551 | 938 | 373 | 7105 | 466 4 | 551 | 938 | 373 | 7105 | 534
8 | 1079 | 1802 | 721 | N/A | 794 8 | 1079 | 1802 | 721 | N/A | 898
Training AlexNet with real data on 8 GPUs was excluded from the graph and table Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to it maxing out the input pipeline. above due to it maxing out the input pipeline.
@ -145,19 +145,19 @@ The results below are all with a batch size of 32.
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | ----- ---- | ----------- | --------- | ---------- | -----
1 | 128 | 210 | 85.3 | 124 1 | 128 | 210 | 85.3 | 144
2 | 259 | 412 | 166 | 241 2 | 259 | 412 | 166 | 281
4 | 520 | 827 | 330 | 470 4 | 520 | 827 | 330 | 549
8 | 995 | 1623 | 643 | 738 8 | 995 | 1623 | 643 | 820
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | ----- ---- | ----------- | --------- | ---------- | -----
1 | 130 | 208 | 85.0 | 124 1 | 130 | 208 | 85.0 | 144
2 | 257 | 403 | 163 | 221 2 | 257 | 403 | 163 | 253
4 | 507 | 814 | 325 | 401 4 | 507 | 814 | 325 | 457
8 | 966 | 1525 | 641 | 619 8 | 966 | 1525 | 641 | 690
## Details for Google Compute Engine (NVIDIA® Tesla® K80) ## Details for Google Compute Engine (NVIDIA® Tesla® K80)
@ -198,19 +198,19 @@ The configuration used for each model was `variable_update` equal to
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.8 | 20.8 | 656 | 30.3 1 | 30.5 | 56.8 | 20.8 | 656 | 35.4
2 | 57.8 | 107 | 39.1 | 1210 | 56.2 2 | 57.8 | 107 | 39.1 | 1209 | 64.8
4 | 116 | 212 | 77.2 | 2330 | 106 4 | 116 | 212 | 77.2 | 2328 | 120
8 | 227 | 419 | 151 | 4640 | 222 8 | 227 | 419 | 151 | 4640 | 234
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.6 | 56.7 | 20.7 | 639 | 30.2 1 | 30.6 | 56.7 | 20.7 | 639 | 34.2
2 | 58.4 | 107 | 39.0 | 1136 | 55.5 2 | 58.4 | 107 | 39.0 | 1136 | 62.9
4 | 115 | 211 | 77.3 | 2067 | 106 4 | 115 | 211 | 77.3 | 2067 | 118
8 | 225 | 422 | 151 | 4056 | 213 8 | 225 | 422 | 151 | 4056 | 230
### Other Results ### Other Results
@ -227,10 +227,10 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | ------------------------- ---- | --------------------------- | -------------------------
1 | 29.5 | 53.6 1 | 29.5 | 53.6
2 | 55.4 | 102 2 | 55.4 | 102
4 | 110 | 201 4 | 110 | 201
8 | 216 | 387 8 | 216 | 387
## Details for Amazon EC2 (NVIDIA® Tesla® K80) ## Details for Amazon EC2 (NVIDIA® Tesla® K80)
@ -279,19 +279,19 @@ VGG16 | parameter_server | gpu
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.8 | 56.3 | 20.9 | 684 | 32.4 1 | 30.8 | 56.3 | 20.9 | 684 | 36.3
2 | 58.7 | 108 | 39.3 | 1244 | 61.5 2 | 58.7 | 108 | 39.3 | 1244 | 69.4
4 | 117 | 217 | 79.1 | 2479 | 123 4 | 117 | 217 | 79.1 | 2479 | 141
8 | 230 | 419 | 156 | 4853 | 234 8 | 230 | 419 | 156 | 4853 | 260
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.0 | 20.6 | 674 | 32.0 1 | 30.5 | 56.0 | 20.6 | 674 | 36.3
2 | 58.7 | 107 | 39.0 | 1227 | 61.0 2 | 59.0 | 107 | 39.0 | 1227 | 67.5
4 | 118 | 205 | 77.9 | 2201 | 120 4 | 118 | 205 | 77.9 | 2201 | 136
8 | 228 | 405 | 152 | N/A | 191 8 | 228 | 405 | 152 | N/A | 242
Training AlexNet with real data on 8 GPUs was excluded from the graph and table Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to our EFS setup not providing enough throughput. above due to our EFS setup not providing enough throughput.
@ -393,63 +393,17 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
32 | 820 | 1265 32 | 820 | 1265
64 | 1608 | 2623 64 | 1608 | 2623
## Appendix
### Executing benchmark tests ## Methodology
The [benchmark code](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks) This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was created to be used for benchmarking TensorFlow as well as used as a tool to was run on the various platforms to generate the above results.
test hardware platforms. Techniques used in the benchmark scripts are detailed @{$performance_models$High-Performance Models} details techniques in the script
in @{$performance_models$High-Performance Models}. along with examples of how to execute the script.
There are two ways to execute the benchmark code: In order to create results that are as repeatable as possible, each test was run
5 times and then the times were averaged together. GPUs are run in their default
1. Execute [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py) state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU
directly. Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/).
2. Utilize the [scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/main.py) For each test, 10 warmup steps are done and then the next 100 steps are
that helps pick the correct config for each platform executes averaged.
`tf_cnn_benchmarks.py`.
The wrapper is suggested as a starting point. Then investigate the variety of
options available in `tf_cnn_benchmarks.py`. Below are a couple examples of
using the wrapper.
**Single Server**
This example illustrates training ResNet-50 on a single instance with 8 GPUs.
The `system` flag is used to determine the optimal configuration. The
supported values are gce, aws, and dgx1. If `system` is not passed, the best
config for the most widely available hardware is used.
```bash
python main.py --model=resnet50 --num_gpus=8
python main.py --system=aws --model=resnet50 --num_gpus=8
```
**Distributed**
This example illustrates training ResNet-50 on 2 hosts, e.g. host_0 (10.0.0.1)
and host_1 (10.0.0.2), with 8 GPUs each on AWS (Amazon EC2).
```bash
# Run the following commands on host_0 (10.0.0.1):
$ python main.py --system=aws --model=resnet50 --job_name=worker
--hosts=10.0.0.1,10.0.0.2 --task_index=0
$ python main.py --system=aws --model=resnet50 --job_name=ps
--hosts=10.0.0.1,10.0.0.2 --task_index=0
# Run the following commands on host_1 (10.0.0.2):
$ python main.py --system=aws --model=resnet50 --job_name=worker
--hosts=10.0.0.1,10.0.0.2 --task_index=1
$ python main.py --system=aws --model=resnet50 --job_name=ps
--hosts=10.0.0.1,10.0.0.2 --task_index=1
```
### Methodology
Unless otherwise stated, each test is run 5 times and then the times are
averaged together. GPUs are run in their default state on the given platform.
For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/)
unless it has been turned off by the provider. For a given test, 10 warmup steps
are done and then the next 100 steps are averaged.

View File

@ -9,7 +9,7 @@ deeper with techniques detailed in @{$performance_models$High-Performance Models
practices for optimizing your TensorFlow code. practices for optimizing your TensorFlow code.
* @{$performance_models$High-Performance Models}, which contains a collection * @{$performance_models$High-Performance Models}, which contains a collection
advanced techniques to build highly scalable models targeting different of advanced techniques to build highly scalable models targeting different
system types and network topologies. system types and network topologies.
* @{$benchmarks$Benchmarks}, which contains a collection of benchmark * @{$benchmarks$Benchmarks}, which contains a collection of benchmark

View File

@ -14,8 +14,8 @@ input pipeline issues and best practices. We found that using @{tf.FIFOQueue}
and @{tf.train.queue_runner} could not saturate multiple current generation GPUs and @{tf.train.queue_runner} could not saturate multiple current generation GPUs
when using large inputs and processing with higher samples per second, such when using large inputs and processing with higher samples per second, such
as training ImageNet with [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). as training ImageNet with [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf).
This is due to the the use of Python threads as its underlying implementation. This is due to the use of Python threads as its underlying implementation. The
The overhead of Python threads is too large. overhead of Python threads is too large.
Another approach, which we have implemented in the Another approach, which we have implemented in the
[scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks), [scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks),
@ -327,3 +327,96 @@ free.
The downside is that all the weights read are from the previous training step. The downside is that all the weights read are from the previous training step.
So it is a different algorithm from SGD. But it is possible to improve its So it is a different algorithm from SGD. But it is possible to improve its
convergence by adjusting learning rate and other hyperparameters. convergence by adjusting learning rate and other hyperparameters.
## Executing the script
This section lists the core command line arguments and a few basic examples for
executing the main script
([tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)).
> Note: `tf_cnn_benchmarks.py` uses the config `force_gpu_compatible`,
> which was introduced after TensorFlow 1.1. Until TensorFlow 1.2 is released
> building from source is advised.
#### Base command line arguments
* **`model`**: Model to use, e.g. `resnet50`, `inception3`, `vgg16`, and
`alexnet`.
* **`num_gpus`**: Number of GPUs to use.
* **`data_dir`**: Path to data to process. If not set, synthetic data is used.
To use Imagenet data use these
[instructions](https://github.com/tensorflow/models/tree/master/inception#getting-started)
as a starting point.
* **`batch_size`**: Batch size for each GPU.
* **`variable_update`**: The method for managing variables: `parameter_server`,
`replicated`, `distributed_replicated`, `independent`.
* **`local_parameter_device`**: Device to use as parameter server: `cpu` or
`gpu`.
#### Single instance examples
```bash
# VGG16 training ImageNet with 8 GPUs using arguments that optimize for
# Google Compute Engine.
python tf_cnn_benchmarks.py --local_parameter_device=cpu --num_gpus=8 \
--batch_size=32 --model=vgg16 --data_dir=/home/ubuntu/imagenet/train \
--variable_update=parameter_server --nodistortions
# VGG16 training synthetic ImageNet data with 8 GPUs using arguments that
# optimize for the NVIDIA DGX-1.
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=vgg16 --variable_update=replicated --use_nccl=True
# VGG16 training ImageNet data with 8 GPUs using arguments that optimize for
# Amazon EC2.
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=vgg16 --variable_update=parameter_server
# ResNet-50 training ImageNet data with 8 GPUs using arguments that optimize for
# Amazon EC2.
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=replicated --use_nccl=False
```
#### Distributed command line arguments
* **`ps_hosts`**: Comma separated list of hosts to use as parameter servers
in the format of ```<host>:port```, e.g. ```10.0.0.2:50000```.
* **`worker_hosts`**: Comma separated list of hosts to use as workers in the
format of ```<host>:port```, e.g. ```10.0.0.2:50001```.
* **`task_index`**: Index of the host in the list of `ps_hosts` or
`worker_hosts` being started.
* **`job_name`**: Type of job, e.g. `ps` or `worker`.
#### Distributed examples
Below is an example of training ResNet-50 on 2 hosts: host_0 (10.0.0.1) and
host_1 (10.0.0.2). The example uses synthetic data. To use real data pass the
`--data_dir` argument.
```bash
# Run the following commands on host_0 (10.0.0.1):
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=worker --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=0
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=ps --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=0
# Run the following commands on host_1 (10.0.0.2):
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=worker --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=1
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=ps --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=1
```

View File

@ -5,7 +5,7 @@ in the way described in the @{$variables$Variables HowTo}.
But when building complex models you often need to share large sets of But when building complex models you often need to share large sets of
variables and you might want to initialize all of them in one place. variables and you might want to initialize all of them in one place.
This tutorial shows how this can be done using `tf.variable_scope()` and This tutorial shows how this can be done using `tf.variable_scope()` and
the `tf.get_variable()`. `tf.get_variable()`.
## The Problem ## The Problem
@ -368,6 +368,6 @@ sequence-to-sequence models.
File | What's in it? File | What's in it?
--- | --- --- | ---
`models/tutorials/image/cifar10/cifar10.py` | Model for detecting objects in images. `tutorials/image/cifar10/cifar10.py` | Model for detecting objects in images.
`models/tutorials/rnn/rnn_cell.py` | Cell functions for recurrent neural networks. `tutorials/rnn/rnn_cell.py` | Cell functions for recurrent neural networks.
`models/tutorials/rnn/seq2seq.py` | Functions for building sequence-to-sequence models. `tutorials/rnn/seq2seq.py` | Functions for building sequence-to-sequence models.

View File

@ -83,7 +83,7 @@ for details. It consists of 1,068,298 learnable parameters and requires about
## Code Organization ## Code Organization
The code for this tutorial resides in The code for this tutorial resides in
[`tensorflow_models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/). [`models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/).
File | Purpose File | Purpose
--- | --- --- | ---

View File

@ -348,12 +348,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
_check_type(sess, session.BaseSession) _check_type(sess, session.BaseSession)
# TODO(cais): Remove this check once tfdbg is integrated with GrpcSession.
if sess.sess_str:
raise NotImplementedError(
"Non-DirectSession support is not available from TensorFlow "
"Debugger yet (sess_str=%s)" % sess.sess_str)
# The session being wrapped. # The session being wrapped.
self._sess = sess self._sess = sess
self._thread_name_filter_pattern = (re.compile(thread_name_filter) self._thread_name_filter_pattern = (re.compile(thread_name_filter)

View File

@ -384,18 +384,6 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
["a_init", "b_init"], ["a_init", "b_init"],
[datum.node_name for datum in dump.dumped_tensor_data]) [datum.node_name for datum in dump.dumped_tensor_data])
def testUsingNonDirectSessionRaisesNotImplementedError(self):
# TODO(cais): Remove this test once tfdbg is integrated with GrpcSession.
fake_non_direct_session = session.Session()
fake_non_direct_session._target = "foo"
with self.assertRaisesRegexp(
NotImplementedError,
r"Non-DirectSession support is not available from TensorFlow Debugger "
r"yet \(sess_str=foo\)"):
TestDebugWrapperSession(
fake_non_direct_session, self._dump_root, self._observer)
if __name__ == "__main__": if __name__ == "__main__":
googletest.main() googletest.main()

View File

@ -139,6 +139,82 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
def make_input_layer(features,
feature_columns,
weight_collections=None,
trainable=True):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
At the first layer of the model, this column oriented data should be converted
to a single `Tensor`.
Example:
```python
price = numeric_column('price')
keywords_embedded = embedding_column(
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
all_feature_columns = [price, keywords_embedded, ...]
dense_tensor = make_input_layer(features, all_feature_columns)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
prediction = tf.layers.dense(dense_tensor, 1)
```
Args:
features: A mapping from key to tensors. `FeatureColumn`s look up via these
keys. For example `numeric_column('price')` will look at 'price' key in
this dict. Values can be a `SparseTensor` or a `Tensor` depending on the
corresponding `FeatureColumn`.
feature_columns: An iterable containing all the `FeatureColumn`s. All items
should be instances of classes derived from `_DenseColumn` such as
`numeric_column`, `embedding_column`, `bucketized_column`,
`indicator_column`. If you have categorical features, you can wrap them
with an `embedding_column` or `indicator_column`.
weight_collections: A list of collection names to which the Variable will be
added. Note that variables will also be added to collections
`tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
Returns:
A `Tensor` which represents input layer of a model. Its shape
is (batch_size, first_layer_dimension) and its dtype is `float32`.
first_layer_dimension is determined based on given `feature_columns`.
Raises:
ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
"""
_check_feature_columns(feature_columns)
for column in feature_columns:
if not isinstance(column, _DenseColumn):
raise ValueError(
'Items of feature_columns must be a _DenseColumn. '
'You can wrap a categorical column with an '
'embedding_column or indicator_column. Given: {}'.format(column))
weight_collections = list(weight_collections or [])
if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
with variable_scope.variable_scope(
None, default_name='make_input_layer', values=features.values()):
builder = _LazyBuilder(features)
output_tensors = []
for column in sorted(feature_columns, key=lambda x: x.name):
with variable_scope.variable_scope(None, default_name=column.name):
tensor = column._get_dense_tensor( # pylint: disable=protected-access
builder,
weight_collections=weight_collections,
trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
output_tensors.append(tensor)
return array_ops.concat(output_tensors, 1)
def make_linear_model(features, def make_linear_model(features,
feature_columns, feature_columns,
units=1, units=1,
@ -156,10 +232,21 @@ def make_linear_model(features,
while `make_input_layer` explicitly requires wrapping each of them with an while `make_input_layer` explicitly requires wrapping each of them with an
`embedding_column` or an `indicator_column`. `embedding_column` or an `indicator_column`.
Example:
```python
price = numeric_column('price')
price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
keywords = categorical_column_with_hash_bucket("keywords", 10K)
all_feature_columns = [price_buckets, keywords, ...]
prediction = make_linear_model(features, all_feature_columns)
```
Args: Args:
features: A mapping from key to tensors. 'string' key means a base feature. features: A mapping from key to tensors. `FeatureColumn`s look up via these
It can have `_FeatureColumn` as a key too. That means that FeatureColumn keys. For example `numeric_column('price')` will look at 'price' key in
is already transformed by the input pipeline. this dict. Values are `Tensor` or `SparseTensor` depending on
corresponding `FeatureColumn`.
feature_columns: An iterable containing all the FeatureColumns. All items feature_columns: An iterable containing all the FeatureColumns. All items
should be instances of classes derived from FeatureColumn. should be instances of classes derived from FeatureColumn.
units: units: An integer, dimensionality of the output space. Default units: units: An integer, dimensionality of the output space. Default
@ -191,22 +278,23 @@ def make_linear_model(features,
raise ValueError('Items of feature_columns must be either a _DenseColumn ' raise ValueError('Items of feature_columns must be either a _DenseColumn '
'or _CategoricalColumn. Given: {}'.format(column)) 'or _CategoricalColumn. Given: {}'.format(column))
weight_collections = list(weight_collections or []) weight_collections = list(weight_collections or [])
weight_collections += [ if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
] if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
with variable_scope.variable_scope( with variable_scope.variable_scope(
None, default_name='make_linear_model', values=features.values()): None, default_name='make_linear_model', values=features.values()):
weigthed_sums = [] weigthed_sums = []
builder = _LazyBuilder(features) builder = _LazyBuilder(features)
for column in sorted(feature_columns, key=lambda x: x.name): for column in sorted(feature_columns, key=lambda x: x.name):
with variable_scope.variable_scope(None, default_name=column.name): with variable_scope.variable_scope(None, default_name=column.name):
if isinstance(column, _DenseColumn): if isinstance(column, _CategoricalColumn):
weigthed_sums.append(_create_dense_column_weighted_sum(
column, builder, units, weight_collections, trainable))
else:
weigthed_sums.append(_create_categorical_column_weighted_sum( weigthed_sums.append(_create_categorical_column_weighted_sum(
column, builder, units, sparse_combiner, weight_collections, column, builder, units, sparse_combiner, weight_collections,
trainable)) trainable))
else:
weigthed_sums.append(_create_dense_column_weighted_sum(
column, builder, units, weight_collections, trainable))
predictions_no_bias = math_ops.add_n( predictions_no_bias = math_ops.add_n(
weigthed_sums, name='weighted_sum_no_bias') weigthed_sums, name='weighted_sum_no_bias')
bias = variable_scope.get_variable( bias = variable_scope.get_variable(
@ -228,7 +316,8 @@ def numeric_column(key,
normalizer_fn=None): normalizer_fn=None):
"""Represents real valued or numerical features. """Represents real valued or numerical features.
An example: Example:
```python ```python
price = numeric_column('price') price = numeric_column('price')
all_feature_columns = [price, ...] all_feature_columns = [price, ...]
@ -237,7 +326,7 @@ def numeric_column(key,
# or # or
bucketized_price = bucketized_column(price, boundaries=[...]) bucketized_price = bucketized_column(price, boundaries=[...])
all_feature_columns = [bucketized_price, ...] all_feature_columns = [bucketized_price, ...]
linear_prediction, _, _ = make_linear_model(features, all_feature_columns) linear_prediction = make_linear_model(features, all_feature_columns)
``` ```
@ -291,6 +380,56 @@ def numeric_column(key,
normalizer_fn=normalizer_fn) normalizer_fn=normalizer_fn)
def bucketized_column(source_column, boundaries):
"""Represents discretized dense input.
Buckets include the left boundary, and exclude the right boundary. Namely,
`boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
`[1., 2.)`, and `[2., +inf)`.
Example:
```python
price = numeric_column('price')
bucketized_price = bucketized_column(price, boundaries=[...])
all_feature_columns = [bucketized_price, ...]
linear_prediction = make_linear_model(features, all_feature_columns)
# or
all_feature_columns = [bucketized_price, ...]
dense_tensor = make_input_layer(features, all_feature_columns)
```
Args:
source_column: A one-dimensional dense column which is generated with
`numeric_column`.
boundaries: A sorted list or tuple of floats specifying the boundaries.
Returns:
A `_BucketizedColumn`.
Raises:
ValueError: If `source_column` is not a numeric column, or if it is not
one-dimensional.
ValueError: If `boundaries` is not a sorted list or tuple.
"""
if not isinstance(source_column, _NumericColumn):
raise ValueError(
'source_column must be a column generated with numeric_column(). '
'Given: {}'.format(source_column))
if len(source_column.shape) > 1:
raise ValueError(
'source_column must be one-dimensional column. '
'Given: {}'.format(source_column))
if (not boundaries or
not (isinstance(boundaries, list) or isinstance(boundaries, tuple))):
raise ValueError('boundaries must be a sorted list.')
for i in range(len(boundaries) - 1):
if boundaries[i] >= boundaries[i + 1]:
raise ValueError('boundaries must be a sorted list.')
return _BucketizedColumn(source_column, tuple(boundaries))
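As an aside, a minimal NumPy sketch (not part of the change above) of the
boundary rule described in the docstring; the sample values are assumed for
illustration only:

```python
import numpy as np

# boundaries=[0., 1., 2.] -> buckets (-inf, 0.), [0., 1.), [1., 2.), [2., +inf)
boundaries = [0., 1., 2.]
values = np.array([-0.5, 0., 1.5, 2.])
# np.digitize with right=False mirrors "include the left boundary, exclude the
# right one", so each value maps to the bucket index described above.
print(np.digitize(values, boundaries, right=False))  # [0 1 2 3]
```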
def categorical_column_with_hash_bucket(key, def categorical_column_with_hash_bucket(key,
hash_bucket_size, hash_bucket_size,
dtype=dtypes.string): dtype=dtypes.string):
@ -300,11 +439,12 @@ def categorical_column_with_hash_bucket(key,
want to distribute your inputs into a finite number of buckets by hashing. want to distribute your inputs into a finite number of buckets by hashing.
output_id = Hash(input_feature_string) % bucket_size output_id = Hash(input_feature_string) % bucket_size
An example: Example:
```python ```python
keywords = categorical_column_with_hash_bucket("keywords", 10K) keywords = categorical_column_with_hash_bucket("keywords", 10K)
linear_prediction, _, _ = make_linear_model(features, all_feature_columns)
all_feature_columns = [keywords, ...] all_feature_columns = [keywords, ...]
linear_prediction = make_linear_model(features, all_feature_columns)
# or # or
keywords_embedded = embedding_column(keywords, 16) keywords_embedded = embedding_column(keywords, 16)
@ -422,7 +562,7 @@ class _DenseColumn(_FeatureColumn):
@abc.abstractproperty @abc.abstractproperty
def _variable_shape(self): def _variable_shape(self):
"""Returns shape of variable which is compatible with _get_dense_tensor.""" """Returns a `TensorShape` of variable compatible with _get_dense_tensor."""
pass pass
@abc.abstractmethod @abc.abstractmethod
@ -431,6 +571,7 @@ class _DenseColumn(_FeatureColumn):
The output of this function will be used by model-buildier-functions. For The output of this function will be used by model-buildier-functions. For
example the pseudo code of `make_input_layer` will be like that: example the pseudo code of `make_input_layer` will be like that:
```python ```python
def make_input_layer(features, feature_columns, ...): def make_input_layer(features, feature_columns, ...):
outputs = [fc._get_dense_tensor(...) for fc in feature_columns] outputs = [fc._get_dense_tensor(...) for fc in feature_columns]
@ -454,7 +595,7 @@ def _create_dense_column_weighted_sum(
builder, builder,
weight_collections=weight_collections, weight_collections=weight_collections,
trainable=trainable) trainable=trainable)
num_elements = tensor_shape.TensorShape(column._variable_shape).num_elements() # pylint: disable=protected-access num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0] batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements)) tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
weight = variable_scope.get_variable( weight = variable_scope.get_variable(
@ -566,12 +707,15 @@ class _LazyBuilder(object):
"""Creates a `_LazyBuilder`. """Creates a `_LazyBuilder`.
Args: Args:
features: A mapping from feature column to tensors. A `string` key features: A mapping from feature column to objects that are `Tensor` or
`SparseTensor`, or can be converted to same via
`sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
signifies a base feature (not-transformed). A `FeatureColumn` key signifies a base feature (not-transformed). A `FeatureColumn` key
means that this `Tensor` is the output of an existing `FeatureColumn` means that this `Tensor` is the output of an existing `FeatureColumn`
which can be reused. which can be reused.
""" """
self._columns_to_tensors = features.copy() self._features = features.copy()
self._feature_tensors = {}
def get(self, key): def get(self, key):
"""Returns a `Tensor` for the given key. """Returns a `Tensor` for the given key.
@ -591,9 +735,16 @@ class _LazyBuilder(object):
ValueError: if key is not found or a transformed `Tensor` cannot be ValueError: if key is not found or a transformed `Tensor` cannot be
computed. computed.
""" """
if key in self._columns_to_tensors: if key in self._feature_tensors:
# Feature_column is already transformed or it's a raw feature. # FeatureColumn is already transformed or converted.
return self._columns_to_tensors[key] return self._feature_tensors[key]
if key in self._features:
# FeatureColumn is a raw feature.
feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
self._features[key])
self._feature_tensors[key] = feature_tensor
return feature_tensor
if not isinstance(key, (str, _FeatureColumn)): if not isinstance(key, (str, _FeatureColumn)):
raise TypeError('"key" must be either a "str" or "_FeatureColumn". ' raise TypeError('"key" must be either a "str" or "_FeatureColumn". '
@ -604,11 +755,13 @@ class _LazyBuilder(object):
column = key column = key
logging.debug('Transforming feature_column %s.', column) logging.debug('Transforming feature_column %s.', column)
transformed = column._transform_feature(self) # pylint: disable=protected-access # pylint: disable=protected-access
transformed = column._transform_feature(self)
# pylint: enable=protected-access
if transformed is None: if transformed is None:
raise ValueError('Column {} is not supported.'.format(column.name)) raise ValueError('Column {} is not supported.'.format(column.name))
self._columns_to_tensors[column] = transformed self._feature_tensors[column] = transformed
return self._columns_to_tensors[column] return transformed
def _check_feature_columns(feature_columns): def _check_feature_columns(feature_columns):
@ -660,7 +813,7 @@ class _NumericColumn(_DenseColumn,
@property @property
def _variable_shape(self): def _variable_shape(self):
return self.shape return tensor_shape.TensorShape(self.shape)
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
del weight_collections del weight_collections
@ -668,6 +821,74 @@ class _NumericColumn(_DenseColumn,
return inputs.get(self) return inputs.get(self)
class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
collections.namedtuple('_BucketizedColumn', [
'source_column', 'boundaries'])):
"""See `bucketized_column`."""
@property
def name(self):
return '{}_bucketized'.format(self.source_column.name)
@property
def _parse_example_config(self):
return self.source_column._parse_example_config # pylint: disable=protected-access
def _transform_feature(self, inputs):
source_tensor = inputs.get(self.source_column)
return math_ops._bucketize( # pylint: disable=protected-access
source_tensor,
boundaries=self.boundaries)
@property
def _variable_shape(self):
return tensor_shape.TensorShape(
tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
del weight_collections
del trainable
input_tensor = inputs.get(self)
return array_ops.one_hot(
indices=math_ops.to_int64(input_tensor),
depth=len(self.boundaries) + 1,
on_value=1.,
off_value=0.)
@property
def _num_buckets(self):
# By construction, source_column is always one-dimensional.
return (len(self.boundaries) + 1) * self.source_column.shape[0]
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
input_tensor = inputs.get(self)
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
source_dimension = self.source_column.shape[0]
i1 = array_ops.reshape(
array_ops.tile(
array_ops.expand_dims(math_ops.range(0, batch_size), 1),
[1, source_dimension]),
(-1,))
i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
# Flatten the bucket indices and unique them across dimensions
# E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
bucket_indices = (
array_ops.reshape(input_tensor, (-1,)) +
(len(self.boundaries) + 1) * i2)
indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2))))
dense_shape = math_ops.to_int64(array_ops.stack(
[batch_size, source_dimension]))
sparse_tensor = sparse_tensor_lib.SparseTensor(
indices=indices,
values=bucket_indices,
dense_shape=dense_shape)
return _CategoricalColumn.IdWeightPair(sparse_tensor, None)
def _create_tuple(shape, value): def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value.""" """Returns a tuple with given shape and filled with value."""
if shape: if shape:

View File

@ -65,7 +65,7 @@ class LazyColumnTest(test.TestCase):
def _parse_example_config(self): def _parse_example_config(self):
pass pass
builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])}) builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
column = TransformCounter() column = TransformCounter()
self.assertEqual(0, column.num_transform) self.assertEqual(0, column.num_transform)
builder.get(column) builder.get(column)
@ -88,7 +88,7 @@ class LazyColumnTest(test.TestCase):
def _parse_example_config(self): def _parse_example_config(self):
pass pass
builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])}) builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
column = Transformer() column = Transformer()
self.assertEqual('Output', builder.get(column)) self.assertEqual('Output', builder.get(column))
self.assertEqual('Output', builder.get(column)) self.assertEqual('Output', builder.get(column))
@ -108,13 +108,13 @@ class LazyColumnTest(test.TestCase):
def _parse_example_config(self): def _parse_example_config(self):
pass pass
features = {'a': constant_op.constant([[2], [3.]])} features = {'a': [[2], [3.]]}
builder = fc._LazyBuilder(features=features) builder = fc._LazyBuilder(features=features)
builder.get(Transformer()) builder.get(Transformer())
self.assertEqual(['a'], list(features.keys())) self.assertEqual(['a'], list(features.keys()))
def test_error_if_feature_is_not_found(self): def test_error_if_feature_is_not_found(self):
builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])}) builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
'bbb is not in features dictionary'): 'bbb is not in features dictionary'):
builder.get('bbb') builder.get('bbb')
@ -135,7 +135,7 @@ class LazyColumnTest(test.TestCase):
def _parse_example_config(self): def _parse_example_config(self):
pass pass
builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])}) builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
'NotAProperColumn is not supported'): 'NotAProperColumn is not supported'):
builder.get(NotAProperColumn()) builder.get(NotAProperColumn())
@ -145,13 +145,13 @@ class LazyColumnTest(test.TestCase):
class NotAFeatureColumn(object): class NotAFeatureColumn(object):
pass pass
builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])}) builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
TypeError, '"key" must be either a "str" or "_FeatureColumn".'): TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
builder.get(NotAFeatureColumn()) builder.get(NotAFeatureColumn())
class NumericalColumnTest(test.TestCase): class NumericColumnTest(test.TestCase):
def test_defaults(self): def test_defaults(self):
a = fc.numeric_column('aaa') a = fc.numeric_column('aaa')
@ -273,7 +273,7 @@ class NumericalColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two) price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
builder = fc._LazyBuilder({ builder = fc._LazyBuilder({
'price': constant_op.constant([[1., 2.], [5., 6.]]) 'price': [[1., 2.], [5., 6.]]
}) })
output = builder.get(price) output = builder.get(price)
with self.test_session(): with self.test_session():
@ -286,7 +286,7 @@ class NumericalColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two) price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
builder = fc._LazyBuilder({ builder = fc._LazyBuilder({
'price': constant_op.constant([[1., 2.], [5., 6.]]) 'price': [[1., 2.], [5., 6.]]
}) })
self.assertEqual(builder.get(price), price._get_dense_tensor(builder)) self.assertEqual(builder.get(price), price._get_dense_tensor(builder))
@ -315,7 +315,7 @@ class NumericalColumnTest(test.TestCase):
def test_make_linear_model(self): def test_make_linear_model(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price]) predictions = fc.make_linear_model(features, [price])
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
@ -327,6 +327,231 @@ class NumericalColumnTest(test.TestCase):
self.assertAllClose([[10.], [50.]], predictions.eval()) self.assertAllClose([[10.], [50.]], predictions.eval())
class BucketizedColumnTest(test.TestCase):
def test_invalid_source_column_type(self):
a = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10)
with self.assertRaisesRegexp(
ValueError,
'source_column must be a column generated with numeric_column'):
fc.bucketized_column(a, boundaries=[0, 1])
def test_invalid_source_column_shape(self):
a = fc.numeric_column('aaa', shape=[2, 3])
with self.assertRaisesRegexp(
ValueError, 'source_column must be one-dimensional column'):
fc.bucketized_column(a, boundaries=[0, 1])
def test_invalid_boundaries(self):
a = fc.numeric_column('aaa')
with self.assertRaisesRegexp(
ValueError, 'boundaries must be a sorted list'):
fc.bucketized_column(a, boundaries=None)
with self.assertRaisesRegexp(
ValueError, 'boundaries must be a sorted list'):
fc.bucketized_column(a, boundaries=1.)
with self.assertRaisesRegexp(
ValueError, 'boundaries must be a sorted list'):
fc.bucketized_column(a, boundaries=[1, 0])
with self.assertRaisesRegexp(
ValueError, 'boundaries must be a sorted list'):
fc.bucketized_column(a, boundaries=[1, 1])
def test_name(self):
a = fc.numeric_column('aaa', dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
self.assertEqual('aaa_bucketized', b.name)
def test_parse_config(self):
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
self.assertEqual({
'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
}, b._parse_example_config)
def test_variable_shape(self):
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
# Column 'aaa' has shape [2] times three buckets -> variable_shape=[2, 3].
self.assertAllEqual((2, 3), b._variable_shape)
def test_num_buckets(self):
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
# Column 'aaa' has shape [2] times three buckets -> num_buckets=6.
self.assertEqual(6, b._num_buckets)
def test_parse_example(self):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
data = example_pb2.Example(features=feature_pb2.Features(
feature={
'price':
feature_pb2.Feature(float_list=feature_pb2.FloatList(
value=[20., 110.]))
}))
features = parsing_ops.parse_example(
serialized=[data.SerializeToString()],
features=bucketized_price._parse_example_config)
self.assertIn('price', features)
with self.test_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_transform_feature(self):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1., 1.], [5., 6.]]
})
transformed_tensor = builder.get(bucketized_price)
with _initialized_session():
self.assertAllEqual([[0, 1], [3, 4]], transformed_tensor.eval())
def test_get_dense_tensor_one_input_value(self):
"""Tests _get_dense_tensor() for input with shape=[1]."""
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1.], [1.], [5.], [6.]]
})
with _initialized_session():
bucketized_price_tensor = bucketized_price._get_dense_tensor(builder)
self.assertAllClose(
# One-hot tensor.
[[[1., 0., 0., 0., 0.]],
[[0., 1., 0., 0., 0.]],
[[0., 0., 0., 1., 0.]],
[[0., 0., 0., 0., 1.]]],
bucketized_price_tensor.eval())
def test_get_dense_tensor_two_input_values(self):
"""Tests _get_dense_tensor() for input with shape=[2]."""
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1., 1.], [5., 6.]]
})
with _initialized_session():
bucketized_price_tensor = bucketized_price._get_dense_tensor(builder)
self.assertAllClose(
# One-hot tensor.
[[[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]],
[[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]],
bucketized_price_tensor.eval())
def test_get_sparse_tensors_one_input_value(self):
"""Tests _get_sparse_tensors() for input with shape=[1]."""
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1.], [1.], [5.], [6.]]
})
with _initialized_session() as sess:
id_weight_pair = bucketized_price._get_sparse_tensors(builder)
self.assertIsNone(id_weight_pair.weight_tensor)
id_tensor_value = sess.run(id_weight_pair.id_tensor)
self.assertAllEqual(
[[0, 0], [1, 0], [2, 0], [3, 0]], id_tensor_value.indices)
self.assertAllEqual([0, 1, 3, 4], id_tensor_value.values)
self.assertAllEqual([4, 1], id_tensor_value.dense_shape)
def test_get_sparse_tensors_two_input_values(self):
"""Tests _get_sparse_tensors() for input with shape=[2]."""
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1., 1.], [5., 6.]]
})
with _initialized_session() as sess:
id_weight_pair = bucketized_price._get_sparse_tensors(builder)
self.assertIsNone(id_weight_pair.weight_tensor)
id_tensor_value = sess.run(id_weight_pair.id_tensor)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1]], id_tensor_value.indices)
# Values 0-4 correspond to the first column of the input price.
# Values 5-9 correspond to the second column of the input price.
self.assertAllEqual([0, 6, 3, 9], id_tensor_value.values)
self.assertAllEqual([2, 2], id_tensor_value.dense_shape)
def test_sparse_tensor_input_not_supported(self):
price = fc.numeric_column('price')
bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
builder = fc._LazyBuilder({
'price':
sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
})
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
bucketized_price._transform_feature(builder)
def test_deep_copy(self):
a = fc.numeric_column('aaa', shape=[2])
a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
a_bucketized_copy = copy.deepcopy(a_bucketized)
self.assertEqual(a_bucketized_copy.name, 'aaa_bucketized')
self.assertAllEqual(a_bucketized_copy._variable_shape, (2, 3))
self.assertEqual(a_bucketized_copy.boundaries, (0, 1))
def test_make_linear_model_one_input_value(self):
"""Tests make_linear_model() for input with shape=[1]."""
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
predictions = fc.make_linear_model(features, [bucketized_price])
bias = get_linear_model_bias()
bucketized_price_var = get_linear_model_column_var(bucketized_price)
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
self.assertAllClose(
[[0.], [0.], [0.], [0.], [0.]], bucketized_price_var.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
sess.run(bucketized_price_var.assign(
[[10.], [20.], [30.], [40.], [50.]]))
# price -1. is in the 0th bucket, whose weight is 10.
# price 1. is in the 1st bucket, whose weight is 20.
# price 5. is in the 3rd bucket, whose weight is 40.
# price 6. is in the 4th bucket, whose weight is 50.
self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
sess.run(bias.assign([1.]))
self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
def test_make_linear_model_two_input_values(self):
"""Tests make_linear_model() for input with shape=[2]."""
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
predictions = fc.make_linear_model(features, [bucketized_price])
bias = get_linear_model_bias()
bucketized_price_var = get_linear_model_column_var(bucketized_price)
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
self.assertAllClose(
[[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
bucketized_price_var.eval())
self.assertAllClose([[0.], [0.]], predictions.eval())
sess.run(bucketized_price_var.assign(
[[10.], [20.], [30.], [40.], [50.],
[60.], [70.], [80.], [90.], [100.]]))
# 1st example:
# price -1. is in the 0th bucket, whose weight is 10.
# price 1. is in the 6th bucket, whose weight is 70.
# 2nd example:
# price 5. is in the 3rd bucket, whose weight is 40.
# price 6. is in the 9th bucket, whose weight is 100.
self.assertAllClose([[80.], [140.]], predictions.eval())
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
class SparseColumnHashedTest(test.TestCase): class SparseColumnHashedTest(test.TestCase):
def test_defaults(self): def test_defaults(self):
@ -396,15 +621,15 @@ class SparseColumnHashedTest(test.TestCase):
float_fc = fc.categorical_column_with_hash_bucket( float_fc = fc.categorical_column_with_hash_bucket(
'a_float', 10, dtype=dtypes.string) 'a_float', 10, dtype=dtypes.string)
int_tensor = sparse_tensor.SparseTensor( int_tensor = sparse_tensor.SparseTensor(
values=constant_op.constant([101]), values=[101],
indices=[[0, 0]], indices=[[0, 0]],
dense_shape=[1, 1]) dense_shape=[1, 1])
string_tensor = sparse_tensor.SparseTensor( string_tensor = sparse_tensor.SparseTensor(
values=constant_op.constant(['101']), values=['101'],
indices=[[0, 0]], indices=[[0, 0]],
dense_shape=[1, 1]) dense_shape=[1, 1])
float_tensor = sparse_tensor.SparseTensor( float_tensor = sparse_tensor.SparseTensor(
values=constant_op.constant([101.]), values=[101.],
indices=[[0, 0]], indices=[[0, 0]],
dense_shape=[1, 1]) dense_shape=[1, 1])
builder = fc._LazyBuilder({ builder = fc._LazyBuilder({
@ -520,7 +745,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_bias(self): def test_dense_bias(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price]) predictions = fc.make_linear_model(features, [price])
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
@ -567,10 +792,63 @@ class MakeLinearModelTest(test.TestCase):
sess.run(price_var.assign([[10.]])) sess.run(price_var.assign([[10.]]))
self.assertAllClose([[1015.], [10065.]], predictions.eval()) self.assertAllClose([[1015.], [10065.]], predictions.eval())
def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""
class _DenseAndSparseColumn(fc._DenseColumn, fc._CategoricalColumn):
@property
def name(self):
return 'dense_and_sparse_column'
@property
def _parse_example_config(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
return inputs.get(self.name)
@property
def _variable_shape(self):
raise ValueError('Should not use this method.')
def _get_dense_tensor(self, inputs, weight_collections=None,
trainable=None):
raise ValueError('Should not use this method.')
@property
def _num_buckets(self):
return 4
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
return fc._CategoricalColumn.IdWeightPair(sp_tensor, None)
dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
sp_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
predictions = fc.make_linear_model(features, [dense_and_sparse_column])
bias = get_linear_model_bias()
dense_and_sparse_column_var = get_linear_model_column_var(
dense_and_sparse_column)
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_multi_output(self): def test_dense_multi_output(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price], units=3) predictions = fc.make_linear_model(features, [price], units=3)
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
@ -607,7 +885,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_multi_dimension(self): def test_dense_multi_dimension(self):
price = fc.numeric_column('price', shape=2) price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[1., 2.], [5., 6.]])} features = {'price': [[1., 2.], [5., 6.]]}
predictions = fc.make_linear_model(features, [price]) predictions = fc.make_linear_model(features, [price])
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
with _initialized_session() as sess: with _initialized_session() as sess:
@ -635,7 +913,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_multi_dimension_multi_output(self): def test_dense_multi_dimension_multi_output(self):
price = fc.numeric_column('price', shape=2) price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[1., 2.], [5., 6.]])} features = {'price': [[1., 2.], [5., 6.]]}
predictions = fc.make_linear_model(features, [price], units=3) predictions = fc.make_linear_model(features, [price], units=3)
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
@ -650,7 +928,7 @@ class MakeLinearModelTest(test.TestCase):
def test_raises_if_shape_mismatch(self): def test_raises_if_shape_mismatch(self):
price = fc.numeric_column('price', shape=2) price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price]) predictions = fc.make_linear_model(features, [price])
with _initialized_session(): with _initialized_session():
with self.assertRaisesRegexp(Exception, 'requested shape has 4'): with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
@ -659,7 +937,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_reshaping(self): def test_dense_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2]) price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default(): with ops.Graph().as_default():
features = {'price': constant_op.constant([[[1., 2.]], [[5., 6.]]])} features = {'price': [[[1., 2.]], [[5., 6.]]]}
predictions = fc.make_linear_model(features, [price]) predictions = fc.make_linear_model(features, [price])
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
@ -675,8 +953,8 @@ class MakeLinearModelTest(test.TestCase):
price2 = fc.numeric_column('price2') price2 = fc.numeric_column('price2')
with ops.Graph().as_default(): with ops.Graph().as_default():
features = { features = {
'price1': constant_op.constant([[1., 2.], [5., 6.]]), 'price1': [[1., 2.], [5., 6.]],
'price2': constant_op.constant([[3.], [4.]]) 'price2': [[3.], [4.]]
} }
predictions = fc.make_linear_model(features, [price1, price2]) predictions = fc.make_linear_model(features, [price1, price2])
bias = get_linear_model_bias() bias = get_linear_model_bias()
@ -695,7 +973,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_collection(self): def test_dense_collection(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
fc.make_linear_model(features, [price], weight_collections=['my-vars']) fc.make_linear_model(features, [price], weight_collections=['my-vars'])
my_vars = g.get_collection('my-vars') my_vars = g.get_collection('my-vars')
bias = get_linear_model_bias() bias = get_linear_model_bias()
@ -720,7 +998,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_trainable_default(self): def test_dense_trainable_default(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
fc.make_linear_model(features, [price]) fc.make_linear_model(features, [price])
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price) price_var = get_linear_model_column_var(price)
@ -744,7 +1022,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_trainable_false(self): def test_dense_trainable_false(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:
features = {'price': constant_op.constant([[1.], [5.]])} features = {'price': [[1.], [5.]]}
fc.make_linear_model(features, [price], trainable=False) fc.make_linear_model(features, [price], trainable=False)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars) self.assertEqual([], trainable_vars)
@ -796,5 +1074,89 @@ class MakeLinearModelTest(test.TestCase):
self.assertIn('wire_cast', my_vars[2].name) self.assertIn('wire_cast', my_vars[2].name)
class MakeInputLayerTest(test.TestCase):
def test_should_be_dense_column(self):
with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
fc.make_input_layer(
features={'a': [[0]]},
feature_columns=[
fc.categorical_column_with_hash_bucket('wire_cast', 4)
])
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
fc.make_input_layer(
features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
fc.make_input_layer(
features={'a': [[0]]},
feature_columns=[fc.numeric_column('a'),
fc.numeric_column('a')])
def test_one_column(self):
price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
self.assertAllClose([[1.], [5.]], net.eval())
def test_multi_dimension(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_raises_if_shape_mismatch(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
net.eval()
def test_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_multi_column(self):
price1 = fc.numeric_column('price1', shape=2)
price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
net = fc.make_input_layer(features, [price1, price2])
with _initialized_session():
self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
def test_column_order(self):
price_a = fc.numeric_column('price_a')
price_b = fc.numeric_column('price_b')
with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
}
net1 = fc.make_input_layer(features, [price_a, price_b])
net2 = fc.make_input_layer(features, [price_b, price_a])
with _initialized_session():
self.assertAllClose([[1., 3.]], net1.eval())
self.assertAllClose([[1., 3.]], net2.eval())
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -422,14 +422,15 @@ def import_scoped_meta_graph(meta_graph_or_file,
graph=None, graph=None,
import_scope=None, import_scope=None,
input_map=None, input_map=None,
unbound_inputs_col_name="unbound_inputs"): unbound_inputs_col_name="unbound_inputs",
"""Recreates a`Graph` saved in a `MetaGraphDef` proto. restore_collections_predicate=(lambda key: True)):
"""Recreates a `Graph` saved in a `MetaGraphDef` proto.
This function takes a `MetaGraphDef` protocol buffer as input. If This function takes a `MetaGraphDef` protocol buffer as input. If
the argument is a file containing a `MetaGraphDef` protocol buffer , the argument is a file containing a `MetaGraphDef` protocol buffer ,
it constructs a protocol buffer from the file content. The function it constructs a protocol buffer from the file content. The function
then adds all the nodes from the `graph_def` field to the then adds all the nodes from the `graph_def` field to the
current graph, recreates all the collections, and returns a saver current graph, recreates the desired collections, and returns a saver
constructed from the `saver_def` field. constructed from the `saver_def` field.
In combination with `export_scoped_meta_graph()`, this function can be used to In combination with `export_scoped_meta_graph()`, this function can be used to
@ -453,6 +454,10 @@ def import_scoped_meta_graph(meta_graph_or_file,
`Tensor` objects. The values of the named input tensors in the imported `Tensor` objects. The values of the named input tensors in the imported
graph will be re-mapped to the respective `Tensor` values. graph will be re-mapped to the respective `Tensor` values.
unbound_inputs_col_name: Collection name for looking up unbound inputs. unbound_inputs_col_name: Collection name for looking up unbound inputs.
restore_collections_predicate: a predicate on collection names. A collection
named c (i.e whose key is c) will be restored iff
1) `restore_collections_predicate(c)` is True, and
2) `c != unbound_inputs_col_name`.
Returns: Returns:
A dictionary of all the `Variables` imported into the name scope. A dictionary of all the `Variables` imported into the name scope.
@ -503,6 +508,8 @@ def import_scoped_meta_graph(meta_graph_or_file,
# Don't add unbound_inputs to the new graph. # Don't add unbound_inputs to the new graph.
if key == unbound_inputs_col_name: if key == unbound_inputs_col_name:
continue continue
if not restore_collections_predicate(key):
continue
kind = col_def.WhichOneof("kind") kind = col_def.WhichOneof("kind")
if kind is None: if kind is None:
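For context, a minimal sketch of how the new `restore_collections_predicate`
argument might be called, assuming an exported meta graph at a hypothetical
path; the test in the next file exercises the same behaviour:

```python
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops

graph = ops.Graph()
# Restore only the TRAINABLE_VARIABLES collection from the exported graph.
var_dict = meta_graph.import_scoped_meta_graph(
    "/tmp/my_model/meta_graph.pb",  # hypothetical path
    graph=graph,
    import_scope="imported",
    restore_collections_predicate=(
        lambda key: key == ops.GraphKeys.TRAINABLE_VARIABLES))
```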

View File

@ -335,6 +335,66 @@ class ScopedMetaGraphTest(test.TestCase):
for a, b in zip(orig_meta_graphs, new_meta_graphs): for a, b in zip(orig_meta_graphs, new_meta_graphs):
test_util.assert_meta_graph_protos_equal(self, a, b) test_util.assert_meta_graph_protos_equal(self, a, b)
def testScopedImportWithSelectedCollections(self):
meta_graph_filename = os.path.join(
_TestDir("selected_collections_import"), "meta_graph.pb")
graph = ops.Graph()
# Add a variable to populate two collections. The functionality tested is
# not specific to variables, but using variables in the test is convenient.
with graph.as_default():
variables.Variable(initial_value=1.0, trainable=True)
self.assertTrue(
all([
graph.get_collection(key)
for key in
[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES]
]))
meta_graph.export_scoped_meta_graph(
filename=meta_graph_filename, graph=graph)
def _test_import(include_collection_keys, omit_collection_keys):
assert set(include_collection_keys).isdisjoint(omit_collection_keys)
newgraph = ops.Graph()
import_scope = "some_scope_name"
def _restore_collections_predicate(collection_key):
return (collection_key in include_collection_keys and
collection_key not in omit_collection_keys)
meta_graph.import_scoped_meta_graph(
meta_graph_filename,
graph=newgraph,
import_scope=import_scope,
restore_collections_predicate=_restore_collections_predicate)
collection_values = [
newgraph.get_collection(name=key, scope=import_scope)
for key in include_collection_keys
]
self.assertTrue(all(collection_values))
collection_values = [
newgraph.get_collection(name=key, scope=import_scope)
for key in omit_collection_keys
]
self.assertFalse(any(collection_values))
_test_import(
include_collection_keys=[
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
],
omit_collection_keys=[])
_test_import(
include_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES],
omit_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES])
_test_import(
include_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES],
omit_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES])
_test_import(
include_collection_keys=[],
omit_collection_keys=[
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
])
def _testScopedExportWithQueue(self, test_dir, exported_filename): def _testScopedExportWithQueue(self, test_dir, exported_filename):
graph = ops.Graph() graph = ops.Graph()
with graph.as_default(): with graph.as_default():

View File

@ -113,10 +113,9 @@ class DepthwiseConv2DTest(test.TestCase):
total_size_1 *= s total_size_1 *= s
for s in filter_in_sizes: for s in filter_in_sizes:
total_size_2 *= s total_size_2 *= s
# Initializes the input tensor with array containing incrementing # Initializes the input and filter tensor with numbers incrementing from 1.
# numbers from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [1.0 for f in range(1, total_size_2 + 1)] x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
with self.test_session(use_gpu=use_gpu) as sess: with self.test_session(use_gpu=use_gpu) as sess:
t1 = constant_op.constant(x1, shape=tensor_in_sizes) t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t1.set_shape(tensor_in_sizes) t1.set_shape(tensor_in_sizes)
@ -147,8 +146,9 @@ class DepthwiseConv2DTest(test.TestCase):
native_result = sess.run(conv_native) native_result = sess.run(conv_native)
interface_result = sess.run(conv_interface) interface_result = sess.run(conv_interface)
print("diff matrix:", print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes,
np.amax(np.ravel(native_result) - np.ravel(interface_result))) ", stride:", stride, ", padding: ", padding, ", max diff: ",
np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear( self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), 1e-5) np.ravel(native_result), np.ravel(interface_result), 1e-5)
self.assertShapeEqual(native_result, conv_native) self.assertShapeEqual(native_result, conv_native)

View File

@ -88,6 +88,7 @@ class SparseAddTest(test.TestCase):
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
sp_sum = sparse_ops.sparse_add(sp_a, sp_b) sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
self.assertAllEqual((3, 3), sp_sum.get_shape())
sum_out = sess.run(sp_sum) sum_out = sess.run(sp_sum)

View File

@ -328,6 +328,12 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6, return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
self._SHP_2_5_6) self._SHP_2_5_6)
def testStaticShapeInfoPreservedWhenNewShapeIsProvidedAndStatic(self):
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 6, 7], dtype=np.int64)
sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
self.assertAllEqual([3, 6, 7], sp_output.get_shape())
def testBasic(self): def testBasic(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6() sp_input = self._SparseTensor_2x5x6()
@ -397,14 +403,21 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError("x == y did not hold element-wise"): with self.assertRaisesOpError("x == y did not hold element-wise"):
sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)}) sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})
def testInvalidDimensionSize(self): def testInvalidDimensionSizeStatic(self):
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 7, 5], dtype=np.int64)
with self.assertRaisesRegexp(ValueError, "should have dimension sizes"):
sparse_ops.sparse_reset_shape(sp_input, new_shape)
def testInvalidDimensionSizeDynamic(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6() sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 7, 5], dtype=np.int64) new_shape = array_ops.placeholder(dtype=dtypes.int32)
out = sparse_ops.sparse_reset_shape(sp_input, new_shape) out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
with self.assertRaisesOpError("x <= y did not hold element-wise"): with self.assertRaisesOpError("x <= y did not hold element-wise"):
sess.run(out) sess.run(out, feed_dict={new_shape: [3, 7, 5]})
def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self): def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32) sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)

View File

@ -48,6 +48,13 @@ class SparseReorderTest(test.TestCase):
shape = np.array([5, 6]).astype(np.int64) shape = np.array([5, 6]).astype(np.int64)
return sparse_tensor.SparseTensorValue(ind, val, shape) return sparse_tensor.SparseTensorValue(ind, val, shape)
def testStaticShapeInfoPreserved(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_5x6(np.arange(6)))
self.assertAllEqual((5, 6), sp_input.get_shape())
sp_output = sparse_ops.sparse_reorder(sp_input)
self.assertAllEqual((5, 6), sp_output.get_shape())
def testAlreadyInOrder(self): def testAlreadyInOrder(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6(np.arange(6)) input_val = self._SparseTensorValue_5x6(np.arange(6))

View File

@ -50,6 +50,13 @@ class SparseReshapeTest(test.TestCase):
shape = np.array([2, 3, 4]) shape = np.array([2, 3, 4])
return sparse_tensor.SparseTensorValue(ind, val, shape) return sparse_tensor.SparseTensorValue(ind, val, shape)
def testStaticShapeInfoPreserved(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_5x6())
self.assertAllEqual((5, 6), sp_input.get_shape())
sp_output = sparse_ops.sparse_reshape(sp_input, shape=(1, 5, 2, 3))
self.assertAllEqual((1, 5, 2, 3), sp_output.get_shape())
def testSameShape(self): def testSameShape(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6() input_val = self._SparseTensorValue_5x6()
@ -180,6 +187,12 @@ class SparseReshapeTest(test.TestCase):
with self.assertRaisesOpError("only one output shape size may be -1"): with self.assertRaisesOpError("only one output shape size may be -1"):
sess.run(sp_output, {sp_input: input_val}) sess.run(sp_output, {sp_input: input_val})
def testProvideStaticallyMismatchedSizes(self):
input_val = self._SparseTensorValue_5x6()
sp_input = sparse_tensor.SparseTensor.from_value(input_val)
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
sparse_ops.sparse_reshape(sp_input, [4, 7])
def testFeedMismatchedSizes(self): def testFeedMismatchedSizes(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder() sp_input = self._SparseTensorPlaceholder()

View File

@ -774,6 +774,11 @@ class VariableScopeTest(test.TestCase):
self.assertEqual([v.name self.assertEqual([v.name
for v in scope.global_variables()], ["foo/b:0"]) for v in scope.global_variables()], ["foo/b:0"])
def testGetVariableWithRefDtype(self):
v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
# Ensure it is possible to do get_variable with a _ref dtype passed in.
_ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype)
def axis0_into1_partitioner(shape=None, **unused_kwargs): def axis0_into1_partitioner(shape=None, **unused_kwargs):
part = [1] * len(shape) part = [1] * len(shape)

View File

@ -335,7 +335,7 @@ class Layer(object):
def add_variable(self, name, shape, dtype=None, def add_variable(self, name, shape, dtype=None,
initializer=None, regularizer=None, trainable=True): initializer=None, regularizer=None, trainable=True):
"""Adds a new variable to the layer. """Adds a new variable to the layer, or gets an existing one; returns it.
Arguments: Arguments:
name: variable name. name: variable name.
@ -424,7 +424,6 @@ class Layer(object):
self.build(input_shapes[0]) self.build(input_shapes[0])
else: else:
self.build(input_shapes) self.build(input_shapes)
self.built = True
if 'scope' in tf_inspect.getargspec(self.call).args: if 'scope' in tf_inspect.getargspec(self.call).args:
kwargs['scope'] = scope kwargs['scope'] = scope
outputs = self.call(inputs, *args, **kwargs) outputs = self.call(inputs, *args, **kwargs)
@ -443,6 +442,7 @@ class Layer(object):
# Update global default collections. # Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
self.built = True
return outputs return outputs
@property @property

View File

@ -153,6 +153,36 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.built, True) self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/Square') self.assertEqual(outputs.op.name, 'my_layer/Square')
def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self):
class MyLayer(base_layers.Layer):
def build(self, _):
# Do not mark the layer as built.
pass
def call(self, inputs):
self.my_var = self.add_variable('my_var', [2, 2])
if self.built:
# Skip creating on the first call; try to create after it's
# built. This is expected to fail.
self.add_variable('this_will_break_on_second_call', [2, 2])
return inputs + math_ops.square(self.my_var)
layer = MyLayer(name='my_layer')
inputs = random_ops.random_uniform((2,), seed=1)
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/add')
self.assertListEqual(
[v.name for v in layer.variables], ['my_layer/my_var:0'])
with self.assertRaisesRegexp(ValueError,
'my_layer/this_will_break_on_second_call'):
layer.apply(inputs)
# The list of variables hasn't changed.
self.assertListEqual(
[v.name for v in layer.variables], ['my_layer/my_var:0'])
def testDeepCopy(self): def testDeepCopy(self):
class MyLayer(base_layers.Layer): class MyLayer(base_layers.Layer):

View File

@ -145,6 +145,7 @@ class _Conv(base.Layer):
dtype=self.dtype) dtype=self.dtype)
else: else:
self.bias = None self.bias = None
self.built = True
def call(self, inputs): def call(self, inputs):
outputs = nn.convolution( outputs = nn.convolution(
@ -837,6 +838,7 @@ class SeparableConv2D(Conv2D):
dtype=self.dtype) dtype=self.dtype)
else: else:
self.bias = None self.bias = None
self.built = True
def call(self, inputs): def call(self, inputs):
if self.data_format == 'channels_first': if self.data_format == 'channels_first':
@ -1070,6 +1072,7 @@ class Conv2DTranspose(Conv2D):
dtype=self.dtype) dtype=self.dtype)
else: else:
self.bias = None self.bias = None
self.built = True
def call(self, inputs): def call(self, inputs):
inputs_shape = array_ops.shape(inputs) inputs_shape = array_ops.shape(inputs)
@ -1297,6 +1300,7 @@ class Conv3DTranspose(Conv3D):
dtype=self.dtype) dtype=self.dtype)
else: else:
self.bias = None self.bias = None
self.built = True
def call(self, inputs): def call(self, inputs):
inputs_shape = array_ops.shape(inputs) inputs_shape = array_ops.shape(inputs)

View File

@ -130,6 +130,7 @@ class Dense(base.Layer):
trainable=True) trainable=True)
else: else:
self.bias = None self.bias = None
self.built = True
def call(self, inputs): def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)

View File

@ -201,6 +201,7 @@ class BatchNormalization(base.Layer):
'renorm_stddev_weight', ()) 'renorm_stddev_weight', ())
finally: finally:
self._scope.set_partitioner(partitioner) self._scope.set_partitioner(partitioner)
self.built = True
def _renorm_correction_and_moments(self, mean, variance, training): def _renorm_correction_and_moments(self, mean, variance, training):
"""Returns the correction and update values for renorm.""" """Returns the correction and update values for renorm."""
@ -399,7 +400,9 @@ def batch_normalization(inputs,
training: Either a Python boolean, or a TensorFlow boolean scalar tensor training: Either a Python boolean, or a TensorFlow boolean scalar tensor
(e.g. a placeholder). Whether to return the output in training mode (e.g. a placeholder). Whether to return the output in training mode
(normalized with statistics of the current batch) or in inference mode (normalized with statistics of the current batch) or in inference mode
(normalized with moving statistics). (normalized with moving statistics). **NOTE**: make sure to set this
parameter correctly, or else your training/inference will not work
properly.
trainable: Boolean, if `True` also add variables to the graph collection trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: String, the name of the layer. name: String, the name of the layer.
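
To illustrate the note above on the `training` argument (a minimal sketch, not part of this diff; variable names are illustrative): the flag switches between batch statistics and moving statistics, and the moving-average updates live in the UPDATE_OPS collection.

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 64])
    is_training = tf.placeholder(tf.bool, name='is_training')

    # Uses batch statistics when is_training is True, moving statistics otherwise.
    y = tf.layers.batch_normalization(x, training=is_training)

    # The moving mean/variance updates are collected in UPDATE_OPS and must be
    # run during training, e.g. via a control dependency on the train op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = tf.no_op(name='train')  # stand-in for an optimizer step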

View File

@ -71,6 +71,7 @@ class _Pooling1D(base.Layer):
if len(input_shape) != 3: if len(input_shape) != 3:
raise ValueError('Inputs should have rank 3. ' raise ValueError('Inputs should have rank 3. '
'Received input shape:', str(input_shape)) 'Received input shape:', str(input_shape))
self.built = True
def call(self, inputs): def call(self, inputs):
# There is no TF op for 1D pooling, hence we make the inputs 4D. # There is no TF op for 1D pooling, hence we make the inputs 4D.
@ -261,6 +262,7 @@ class _Pooling2D(base.Layer):
if len(input_shape) != 4: if len(input_shape) != 4:
raise ValueError('Inputs should have rank 4. ' raise ValueError('Inputs should have rank 4. '
'Received input shape:', str(input_shape)) 'Received input shape:', str(input_shape))
self.built = True
def call(self, inputs): def call(self, inputs):
if self.data_format == 'channels_last': if self.data_format == 'channels_last':
@ -448,6 +450,7 @@ class _Pooling3D(base.Layer):
if len(input_shape) != 5: if len(input_shape) != 5:
raise ValueError('Inputs should have rank 5. ' raise ValueError('Inputs should have rank 5. '
'Received input shape:', str(input_shape)) 'Received input shape:', str(input_shape))
self.built = True
def call(self, inputs): def call(self, inputs):
pool_shape = (1,) + self.pool_size + (1,) pool_shape = (1,) + self.pool_size + (1,)

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import collections import collections
import hashlib import hashlib
import re
import threading import threading
import six import six
@ -56,6 +55,7 @@ def _as_type_list(dtypes):
def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False, def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False,
unknown_rank_allowed=False): unknown_rank_allowed=False):
"""Convert shapes to a list of tuples of int (or None).""" """Convert shapes to a list of tuples of int (or None)."""
del dtypes
if unknown_dim_allowed: if unknown_dim_allowed:
if (not isinstance(shapes, collections.Sequence) if (not isinstance(shapes, collections.Sequence)
or not shapes or not shapes
@ -925,16 +925,18 @@ class Barrier(object):
If barrier has no completed elements, this operation will block If barrier has no completed elements, this operation will block
until there are 'num_elements' elements to take. until there are 'num_elements' elements to take.
TODO(b/25743580): the semantics of `allow_small_batch` are experimental
and may be extended to other cases in the future.
TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
already when the barrier is closed, it will block forever. Fix this
by using asynchronous operations.
Args: Args:
num_elements: The number of elements to take. num_elements: The number of elements to take.
allow_small_batch: If the barrier is closed, don't block if there are less allow_small_batch: If the barrier is closed, don't block if there are less
completed elements than requested, but instead return all available completed elements than requested, but instead return all available
completed elements. completed elements.
TODO(b/25743580): the semantics of `allow_small_batch` are experimental
and may be extended to other cases in the future.
TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
already when the barrier is closed, it will block forever. Fix this
by using asynchronous operations.
timeout: This specifies the number of milliseconds to block timeout: This specifies the number of milliseconds to block
before returning with DEADLINE_EXCEEDED. (This option is not before returning with DEADLINE_EXCEEDED. (This option is not
supported yet.) supported yet.)

View File

@ -51,6 +51,7 @@ import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -288,12 +289,21 @@ def sparse_add(a, b, thresh=0):
if all(isinstance(inp, sparse_classes) for inp in [a, b]): if all(isinstance(inp, sparse_classes) for inp in [a, b]):
a = _convert_to_sparse_tensor(a) a = _convert_to_sparse_tensor(a)
b = _convert_to_sparse_tensor(b)
thresh = ops.convert_to_tensor( thresh = ops.convert_to_tensor(
thresh, dtype=a.values.dtype.real_dtype, name="thresh") thresh, dtype=a.values.dtype.real_dtype, name="thresh")
output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add( output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
a.indices, a.values, a.dense_shape, a.indices, a.values, a.dense_shape,
b.indices, b.values, b.dense_shape, b.indices, b.values, b.dense_shape,
thresh)) thresh))
# Attempt to get output_shape statically.
a.get_shape().assert_is_compatible_with(b.get_shape())
static_shape = array_ops.broadcast_static_shape(
a.get_shape(), b.get_shape())
if static_shape.is_fully_defined():
output_shape = static_shape.as_list()
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
else: else:
# swap to make `a` the SparseTensor. # swap to make `a` the SparseTensor.
@ -368,8 +378,12 @@ def sparse_reorder(sp_input, name=None):
reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder( reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)) sp_input.indices, sp_input.values, sp_input.dense_shape, name=name))
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, if sp_input.get_shape().is_fully_defined():
array_ops.identity(sp_input.dense_shape)) dense_shape = sp_input.get_shape().as_list()
else:
dense_shape = array_ops.identity(sp_input.dense_shape)
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
def sparse_reshape(sp_input, shape, name=None): def sparse_reshape(sp_input, shape, name=None):
@ -416,13 +430,30 @@ def sparse_reshape(sp_input, shape, name=None):
Raises: Raises:
TypeError: If `sp_input` is not a `SparseTensor`. TypeError: If `sp_input` is not a `SparseTensor`.
ValueError: If argument `shape` requests a `SparseTensor` with a different
number of elements than `sp_input`.
""" """
sp_input = _convert_to_sparse_tensor(sp_input) sp_input = _convert_to_sparse_tensor(sp_input)
shape = ops.convert_to_tensor(shape, dtype=dtypes.int64)
with ops.name_scope(name, "SparseReshape", [sp_input]) as name: with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape( reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
sp_input.indices, sp_input.dense_shape, shape, name=name) sp_input.indices, sp_input.dense_shape, shape, name=name)
reshaped_shape_const = tensor_util.constant_value(shape)
if (reshaped_shape_const is not None
and sp_input.get_shape().is_fully_defined()):
# Don't deal with inferred dimensions. That would add significant code.
if all(n >= 0 for n in reshaped_shape_const):
reshaped_size = np.prod(reshaped_shape_const)
in_shape_size = np.prod(sp_input.get_shape().as_list())
if reshaped_size != in_shape_size:
raise ValueError(
"Cannot reshape a tensor with %d elements to shape %s "
"(%d elements)."
% (in_shape_size, reshaped_shape_const, reshaped_size))
reshaped_shape = reshaped_shape_const
return sparse_tensor.SparseTensor( return sparse_tensor.SparseTensor(
reshaped_ind, array_ops.identity(sp_input.values), reshaped_ind, array_ops.identity(sp_input.values),
reshaped_shape) reshaped_shape)
@ -986,6 +1017,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
TypeError: If `sp_input` is not a `SparseTensor`. TypeError: If `sp_input` is not a `SparseTensor`.
ValueError: If `new_shape` represents a tensor with a different rank from ValueError: If `new_shape` represents a tensor with a different rank from
that of `sp_input` (if shapes are known when graph is constructed). that of `sp_input` (if shapes are known when graph is constructed).
ValueError: If `new_shape` is determined during graph build to have
dimension sizes that are too small.
OpError: OpError:
- If `new_shape` has dimension sizes that are too small. - If `new_shape` has dimension sizes that are too small.
- If shapes are not known during graph construction time, and during run - If shapes are not known during graph construction time, and during run
@ -1009,14 +1042,27 @@ def sparse_reset_shape(sp_input, new_shape=None):
# error before the sparse_tensor.SparseTensor catches it. # error before the sparse_tensor.SparseTensor catches it.
output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0]) output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])
# For cases where shape is not known during graph construction. output_shape_tensor_const = tensor_util.constant_value(
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_equal(
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
output_shape_tensor)
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
output_shape_tensor) output_shape_tensor)
# For cases where all shapes are known during graph construction
if (output_shape_tensor_const is not None
and sp_input.get_shape().is_fully_defined()):
in_shape_const = np.array(sp_input.get_shape().as_list())
if not np.all(in_shape_const <= output_shape_tensor_const):
raise ValueError(
"Requested new_shape should have dimension sizes >= sp_input.shape."
" Found new_shape (%s), sp_input.shape (%s)."
% (in_shape_const, output_shape_tensor_const))
output_shape_tensor = output_shape_tensor_const
else:
# For cases where shape is not known during graph construction.
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_equal(
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
output_shape_tensor)
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
output_shape_tensor)
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor) return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
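
With the static-shape paths added above, sparse_reshape can reject a size mismatch while the graph is being built instead of at session run time. A minimal sketch of the new behaviour (not part of this diff; the shapes are chosen for illustration):

    import tensorflow as tf

    sp = tf.SparseTensor(indices=[[0, 0], [4, 5]], values=[1.0, 2.0],
                         dense_shape=[5, 6])

    # The input's static shape (5, 6) is fully defined, so the 28-vs-30
    # element mismatch is caught at graph construction time.
    try:
      tf.sparse_reshape(sp, [4, 7])
    except ValueError as e:
      print(e)  # "Cannot reshape a tensor with 30 elements to shape ..."

    # A compatible reshape goes through and keeps a fully defined dense shape.
    reshaped = tf.sparse_reshape(sp, [3, 10])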

View File

@ -280,6 +280,17 @@ class _VariableStore(object):
raise ValueError( raise ValueError(
"Passed a custom_getter which is not callable: %s" % custom_getter) "Passed a custom_getter which is not callable: %s" % custom_getter)
# If a *_ref type is passed in an error would be triggered further down the
# stack. We prevent this using base_dtype to get a non-ref version of the
# type, before doing anything else. When _ref types are removed in favour of
# resources, this line can be removed.
try:
dtype = dtype.base_dtype
except AttributeError:
# .base_dtype not existing means that we will try and use the raw dtype
# which was passed in - this might be a NumPy type which is valid.
pass
# This is the main logic of get_variable. However, custom_getter # This is the main logic of get_variable. However, custom_getter
# may override this logic. So we save it as a callable and pass # may override this logic. So we save it as a callable and pass
# it to custom_getter. # it to custom_getter.
@ -1281,7 +1292,7 @@ def _pure_variable_scope(name_or_scope,
well-defined semantics. Defaults to False (will later change to True). well-defined semantics. Defaults to False (will later change to True).
Yields: Yields:
A scope that can be to captured and reused. A scope that can be captured and reused.
Raises: Raises:
ValueError: when trying to reuse within a create scope, or create within ValueError: when trying to reuse within a create scope, or create within
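
The comment above describes why the dtype is normalized with base_dtype: a graph-mode variable's dtype property is a _ref type, and feeding it straight back into get_variable used to fail further down the stack. A minimal sketch (not part of this diff):

    import tensorflow as tf

    v = tf.get_variable('v', shape=[3, 4], dtype=tf.float32)
    print(v.dtype)             # float32_ref for a non-resource variable
    print(v.dtype.base_dtype)  # float32

    # With the base_dtype normalization above, passing the ref dtype works.
    w = tf.get_variable('w', shape=[5, 6], dtype=v.dtype)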

View File

@ -56,20 +56,22 @@ Example output:
To show all available information in the SavedModel: To show all available information in the SavedModel:
$saved_model_cli show --dir /tmp/saved_model --all $saved_model_cli show --dir /tmp/saved_model --all
'run' command usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def
--signature_def SIGNATURE_DEF_KEY --inputs INPUTS SIGNATURE_DEF_KEY [--inputs INPUTS]
[--outdir OUTDIR] [--overwrite] [--input_exprs INPUT_EXPRS] [--outdir OUTDIR]
[--overwrite] [--tf_debug]
Examples: Examples:
To run input tensors from files through a MetaGraphDef and save the output To run input tensors from files through a MetaGraphDef and save the output
tensors to files: tensors to files:
$saved_model_cli run --dir /tmp/saved_model --tag_set serve $saved_model_cli run --dir /tmp/saved_model --tag_set serve
--signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy --signature_def serving_default --inputs x=/tmp/124.npz
--outdir /tmp/out --input_exprs 'x2=np.ones((6,2))' --outdir /tmp/out
To observe the intermediate Tensor values in the runtime graph, use the To observe the intermediate Tensor values in the runtime graph, use the
--tf_debug flag, e.g.: --tf_debug flag, e.g.:
$saved_model_cli run --dir /tmp/saved_model --tag_set serve $saved_model_cli run --dir /tmp/saved_model --tag_set serve
--signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy --signature_def serving_default --inputs 'x=/tmp/124.npz;x2=/tmp/123.npy'
--outdir /tmp/out --tf_debug --outdir /tmp/out --tf_debug
To build this tool from source, run: To build this tool from source, run:
@ -367,7 +369,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
output_full_path)) output_full_path))
def preprocess_input_arg_string(inputs_str): def preprocess_inputs_arg_string(inputs_str):
"""Parses input arg into dictionary that maps input to file/variable tuple. """Parses input arg into dictionary that maps input to file/variable tuple.
Parses input string in the format of, for example, Parses input string in the format of, for example,
@ -375,74 +377,94 @@ def preprocess_input_arg_string(inputs_str):
dictionary looks like dictionary looks like
{'input_key1': (filename1, variable_name1), {'input_key1': (filename1, variable_name1),
'input_key2': (file2, None)} 'input_key2': (file2, None)}
, which maps input keys to a tuple of file name and varaible name(None if , which maps input keys to a tuple of file name and variable name(None if
empty). empty).
Args: Args:
inputs_str: A string that specifies where to load inputs. Each input is inputs_str: A string that specifies where to load inputs. Inputs are
separated by comma. separated by semicolons.
* If the command line arg for inputs is quoted and contains
whitespace(s), all whitespaces will be ignored.
* For each input key: * For each input key:
'input=filename<[variable_name]>' '<input_key>=<filename>' or
* The "[variable_name]" key is optional. Will be set to None if not '<input_key>=<filename>[<variable_name>]'
specified. * The optional 'variable_name' key will be set to None if not specified.
Returns: Returns:
A dictionary that maps input keys to a tuple of file name and varaible name. A dictionary that maps input keys to a tuple of file name and variable name.
Raises: Raises:
RuntimeError: An error when the given input is in a bad format. RuntimeError: An error when the given input string is in a bad format.
""" """
input_dict = {} input_dict = {}
inputs_raw = inputs_str.split(',') inputs_raw = inputs_str.split(';')
for input_raw in filter(bool, inputs_raw): # skip empty strings for input_raw in filter(bool, inputs_raw): # skip empty strings
# Remove quotes and whitespaces
input_raw = input_raw.replace('"', '').replace('\'', '').replace(' ', '')
# Format of input=filename[variable_name]' # Format of input=filename[variable_name]'
match = re.match(r'^([\w\-]+)=([\w\-.\/]+)\[([\w\-]+)\]$', input_raw) match = re.match(r'([^=]+)=([^\[\]]+)\[([^\[\]]+)\]$', input_raw)
if match: if match:
input_dict[match.group(1)] = (match.group(2), match.group(3)) input_dict[match.group(1)] = match.group(2), match.group(3)
else: else:
# Format of input=filename' # Format of input=filename'
match = re.match(r'^([\w\-]+)=([\w\-.\/]+)$', input_raw) match = re.match(r'([^=]+)=([^\[\]]+)$', input_raw)
if match: if match:
input_dict[match.group(1)] = (match.group(2), None) input_dict[match.group(1)] = match.group(2), None
else: else:
raise RuntimeError( raise RuntimeError(
'Input \"%s\" format is incorrect. Please follow \"--inputs ' '--inputs "%s" format is incorrect. Please follow'
'input_key=file_name[variable_name]\" or input_key=file_name' % '"<input_key>=<filename>", or'
input_raw) '"<input_key>=<filename>[<variable_name>]"' % input_raw)
return input_dict return input_dict
def load_inputs_from_input_arg_string(inputs_str): def preprocess_input_exprs_arg_string(input_exprs_str):
"""Parses input arg string and load inputs into a dictionary. """Parses input arg into dictionary that maps input key to python expression.
Parses input string in the format of, for example, Parses input string in the format of 'input_key=<python expression>' into a
"input1=filename1[variable_name1],input2=filename2" into a dictionary that maps each input_key to its python expression.
dictionary looks like
{'input1:0': ndarray_saved_as_variable_name1_in_filename1 , Args:
'input2:0': ndarray_saved_in_filename2} input_exprs_str: A string that specifies python expression for input keys.
, which maps input keys to a numpy ndarray loaded from file. See Args section Each input is separated by semicolon. For each input key:
for more details on inputs format. 'input_key=<python expression>'
Returns:
A dictionary that maps input keys to python expressions.
Raises:
RuntimeError: An error when the given input string is in a bad format.
"""
input_dict = {}
for input_raw in filter(bool, input_exprs_str.split(';')):
if '=' not in input_exprs_str:
raise RuntimeError('--input_exprs "%s" format is incorrect. Please follow'
'"<input_key>=<python expression>"' % input_exprs_str)
input_key, expr = input_raw.split('=')
input_dict[input_key] = expr
return input_dict
def load_inputs_from_input_arg_string(inputs_str, input_exprs_str):
"""Parses input arg strings and create inputs feed_dict.
Parses '--inputs' string for inputs to be loaded from file, and parses
'--input_exprs' string for inputs to be evaluated from python expression.
Args: Args:
inputs_str: A string that specifies where to load inputs. Each input is inputs_str: A string that specifies where to load inputs. Each input is
separated by comma. separated by semicolon.
* If the command line arg for inputs is quoted and contains
whitespace(s), all whitespaces will be ignored.
* For each input key: * For each input key:
'input=filename[variable_name]' '<input_key>=<filename>' or
'<input_key>=<filename>[<variable_name>]'
* The optional 'variable_name' key will be set to None if not specified.
* File specified by 'filename' will be loaded using numpy.load. Inputs * File specified by 'filename' will be loaded using numpy.load. Inputs
can be loaded from only .npy, .npz or pickle files. can be loaded from only .npy, .npz or pickle files.
* The "[variable_name]" key is optional depending on the input file type * The "[variable_name]" key is optional depending on the input file type
as described in more detail below. as described in more detail below.
When loading from a npy file, which always contains a numpy ndarray, the When loading from a npy file, which always contains a numpy ndarray, the
content will be directly assigned to the specified input tensor. If a content will be directly assigned to the specified input tensor. If a
varaible_name is specified, it will be ignored and a warning will be variable_name is specified, it will be ignored and a warning will be
issued. issued.
When loading from a npz zip file, user can specify which variable within When loading from a npz zip file, user can specify which variable within
the zip file to load for the input tensor inside the square brackets. If the zip file to load for the input tensor inside the square brackets. If
@ -453,10 +475,12 @@ def load_inputs_from_input_arg_string(inputs_str):
to the specified input tensor, else SavedModel CLI will assume a to the specified input tensor, else SavedModel CLI will assume a
dictionary is stored in the pickle file and the value corresponding to dictionary is stored in the pickle file and the value corresponding to
the variable_name will be used. the variable_name will be used.
input_exprs_str: A string that specifies python expressions for inputs.
* In the format of: '<input_key>=<python expression>'.
* numpy module is available as np.
Returns: Returns:
A dictionary that maps input tensor keys to a numpy ndarray loaded from A dictionary that maps input tensor keys to numpy ndarrays.
file.
Raises: Raises:
RuntimeError: An error when a key is specified, but the input file contains RuntimeError: An error when a key is specified, but the input file contains
@ -466,13 +490,14 @@ def load_inputs_from_input_arg_string(inputs_str):
""" """
tensor_key_feed_dict = {} tensor_key_feed_dict = {}
for input_tensor_key, ( inputs = preprocess_inputs_arg_string(inputs_str)
filename, input_exprs = preprocess_input_exprs_arg_string(input_exprs_str)
variable_name) in preprocess_input_arg_string(inputs_str).items():
for input_tensor_key, (filename, variable_name) in inputs.items():
data = np.load(filename)
# When a variable_name key is specified for the input file # When a variable_name key is specified for the input file
if variable_name: if variable_name:
data = np.load(filename)
# if file contains a single ndarray, ignore the input name # if file contains a single ndarray, ignore the input name
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
warnings.warn( warnings.warn(
@ -488,7 +513,6 @@ def load_inputs_from_input_arg_string(inputs_str):
(filename, variable_name)) (filename, variable_name))
# When no key is specified for the input file. # When no key is specified for the input file.
else: else:
data = np.load(filename)
# Check if npz file only contains a single numpy ndarray. # Check if npz file only contains a single numpy ndarray.
if isinstance(data, np.lib.npyio.NpzFile): if isinstance(data, np.lib.npyio.NpzFile):
variable_name_list = data.files variable_name_list = data.files
@ -500,6 +524,16 @@ def load_inputs_from_input_arg_string(inputs_str):
else: else:
tensor_key_feed_dict[input_tensor_key] = data tensor_key_feed_dict[input_tensor_key] = data
# When input is a python expression:
for input_tensor_key, py_expr in input_exprs.items():
if input_tensor_key in tensor_key_feed_dict:
warnings.warn(
'input_key %s has been specified with both --inputs and --input_exprs'
' options. Value in --input_exprs will be used.' % input_tensor_key)
# ast.literal_eval does not work with numpy expressions
tensor_key_feed_dict[input_tensor_key] = eval(py_expr) # pylint: disable=eval-used
return tensor_key_feed_dict return tensor_key_feed_dict
@ -531,7 +565,8 @@ def run(args):
Args: Args:
args: A namespace parsed from command line. args: A namespace parsed from command line.
""" """
tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs) tensor_key_feed_dict = load_inputs_from_input_arg_string(
args.inputs, args.input_exprs)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir, tensor_key_feed_dict, args.outdir,
args.overwrite, tf_debug=args.tf_debug) args.overwrite, tf_debug=args.tf_debug)
@ -559,7 +594,7 @@ def create_parser():
'MetaGraphDef specified by its tag-set:\n' 'MetaGraphDef specified by its tag-set:\n'
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n' '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n'
'For a MetaGraphDef with multiple tags in the tag-set, all tags must be ' 'For a MetaGraphDef with multiple tags in the tag-set, all tags must be '
'passed in, separated by \',\':\n' 'passed in, separated by \';\':\n'
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n' '$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n'
'To show all inputs and outputs TensorInfo for a specific' 'To show all inputs and outputs TensorInfo for a specific'
' SignatureDef specified by the SignatureDef key in a' ' SignatureDef specified by the SignatureDef key in a'
@ -601,7 +636,7 @@ def create_parser():
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve' '$saved_model_cli show --dir /tmp/saved_model --tag_set serve'
'--signature_def serving_default ' '--signature_def serving_default '
'--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy' '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy'
'--outdir=/out\n\n' '--input_exprs \'input3_key=np.ones(2)\' --outdir=/out\n\n'
'For more information about input file format, please see:\n' 'For more information about input file format, please see:\n'
'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n')
parser_run = subparsers.add_parser( parser_run = subparsers.add_parser(
@ -622,10 +657,15 @@ def create_parser():
required=True, required=True,
metavar='SIGNATURE_DEF_KEY', metavar='SIGNATURE_DEF_KEY',
help='key of SignatureDef to run') help='key of SignatureDef to run')
msg = ('inputs in the format of \'input_key=filename[variable_name]\', ' msg = ('Loading inputs from files, in the format of \'<input_key>=<filename>,'
'separated by \',\'. Inputs can only be loaded from .npy, .npz or ' ' or \'<input_key>=<filename>[<variable_name>]\', separated by \';\'.'
'pickle files. Please use input keys instead of input names.') ' The file format can only be from .npy, .npz or pickle.')
parser_run.add_argument('--inputs', type=str, required=True, help=msg) parser_run.add_argument('--inputs', type=str, default='', help=msg)
msg = ('Specifying inputs by python expressions, in the format of'
' "<input_key>=\'<python expression>\'", separated by \';\'. '
'numpy module is available as \'np\'. '
'Will override duplicate input_keys from --inputs option.')
parser_run.add_argument('--input_exprs', type=str, default='', help=msg)
parser_run.add_argument( parser_run.add_argument(
'--outdir', '--outdir',
type=str, type=str,
@ -649,6 +689,8 @@ def create_parser():
def main(): def main():
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
if not args.inputs and not args.input_exprs:
args.error('At least one of --inputs and --input_exprs is required')
args.func(args) args.func(args)
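
The new --input_exprs flag is parsed by preprocess_input_exprs_arg_string and merged with --inputs inside load_inputs_from_input_arg_string, with expression values overriding duplicate keys. A small sketch of the parsing (not part of this diff; the .npz path is illustrative and written here only to keep the snippet runnable):

    import numpy as np
    from tensorflow.python.tools import saved_model_cli

    # '<input_key>=<python expression>' entries, separated by ';'.
    exprs = saved_model_cli.preprocess_input_exprs_arg_string(
        'x=np.ones((2, 2));y=[[1], [2]]')
    print(exprs)  # {'x': 'np.ones((2, 2))', 'y': '[[1], [2]]'}

    # Mix a file-based input with an expression-based input in one feed_dict.
    np.savez('/tmp/demo_input.npz', a=np.array([[1], [2]]))
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        'x0=/tmp/demo_input.npz[a]', 'x1=np.ones((2, 2))')
    print(sorted(feed_dict.keys()))  # ['x0', 'x1']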

View File

@ -201,28 +201,37 @@ Method name is: tensorflow/serving/predict"""
self.assertEqual(err.getvalue().strip(), '') self.assertEqual(err.getvalue().strip(), '')
def testInputPreProcessFormats(self): def testInputPreProcessFormats(self):
input_str = 'input1=/path/file.txt[ab3], input2=file2,,' input_str = 'input1=/path/file.txt[ab3];input2=file2'
input_dict = saved_model_cli.preprocess_input_arg_string(input_str) input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]'
input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
input_expr_dict = saved_model_cli.preprocess_input_exprs_arg_string(
input_expr_str)
self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3')) self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3'))
self.assertTrue(input_dict['input2'] == ('file2', None)) self.assertTrue(input_dict['input2'] == ('file2', None))
self.assertTrue(input_expr_dict['input3'] == 'np.zeros([2,2])')
def testInputPreProcessQuoteAndWhitespace(self): self.assertTrue(input_expr_dict['input4'] == '[4,5]')
input_str = '\' input1 = file[v_1]\', input2=file ["sd"] '
input_dict = saved_model_cli.preprocess_input_arg_string(input_str)
self.assertTrue(input_dict['input1'] == ('file', 'v_1'))
self.assertTrue(input_dict['input2'] == ('file', 'sd'))
self.assertTrue(len(input_dict) == 2) self.assertTrue(len(input_dict) == 2)
self.assertTrue(len(input_expr_dict) == 2)
def testInputPreProcessFileNames(self):
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
r'input:0=c:\PROGRA~1\data.npy')
input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
print(input_dict)
self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz',
'v:0'))
self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))
def testInputPreProcessErrorBadFormat(self): def testInputPreProcessErrorBadFormat(self):
input_str = 'inputx=file[[v1]v2' input_str = 'inputx=file[[v1]v2'
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
saved_model_cli.preprocess_input_arg_string(input_str) saved_model_cli.preprocess_inputs_arg_string(input_str)
input_str = 'inputx:file' input_str = 'inputx:file'
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
saved_model_cli.preprocess_input_arg_string(input_str) saved_model_cli.preprocess_inputs_arg_string(input_str)
input_str = 'inputx=file(v_1)' input_str = 'inputx:np.zeros((5))'
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
saved_model_cli.preprocess_input_arg_string(input_str) saved_model_cli.preprocess_input_exprs_arg_string(input_str)
def testInputParserNPY(self): def testInputParserNPY(self):
x0 = np.array([[1], [2]]) x0 = np.array([[1], [2]])
@ -231,8 +240,8 @@ Method name is: tensorflow/serving/predict"""
input1_path = os.path.join(test.get_temp_dir(), 'input1.npy') input1_path = os.path.join(test.get_temp_dir(), 'input1.npy')
np.save(input0_path, x0) np.save(input0_path, x0)
np.save(input1_path, x1) np.save(input1_path, x1)
input_str = 'x0=' + input0_path + '[x0],x1=' + input1_path input_str = 'x0=' + input0_path + '[x0];x1=' + input1_path
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str) feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
self.assertTrue(np.all(feed_dict['x0'] == x0)) self.assertTrue(np.all(feed_dict['x0'] == x0))
self.assertTrue(np.all(feed_dict['x1'] == x1)) self.assertTrue(np.all(feed_dict['x1'] == x1))
@ -240,8 +249,8 @@ Method name is: tensorflow/serving/predict"""
x0 = np.array([[1], [2]]) x0 = np.array([[1], [2]])
input_path = os.path.join(test.get_temp_dir(), 'input.npz') input_path = os.path.join(test.get_temp_dir(), 'input.npz')
np.savez(input_path, a=x0) np.savez(input_path, a=x0)
input_str = 'x=' + input_path + '[a],y=' + input_path input_str = 'x=' + input_path + '[a];y=' + input_path
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str) feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
self.assertTrue(np.all(feed_dict['x'] == x0)) self.assertTrue(np.all(feed_dict['x'] == x0))
self.assertTrue(np.all(feed_dict['y'] == x0)) self.assertTrue(np.all(feed_dict['y'] == x0))
@ -258,25 +267,50 @@ Method name is: tensorflow/serving/predict"""
pickle.dump(pkl1, f) pickle.dump(pkl1, f)
with open(input_path2, 'wb') as f: with open(input_path2, 'wb') as f:
pickle.dump(pkl2, f) pickle.dump(pkl2, f)
input_str = 'x=' + input_path0 + '[b],y=' + input_path1 + '[c],' input_str = 'x=' + input_path0 + '[b];y=' + input_path1 + '[c];'
input_str += 'z=' + input_path2 input_str += 'z=' + input_path2
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str) feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
self.assertTrue(np.all(feed_dict['x'] == pkl0['b'])) self.assertTrue(np.all(feed_dict['x'] == pkl0['b']))
self.assertTrue(np.all(feed_dict['y'] == pkl1)) self.assertTrue(np.all(feed_dict['y'] == pkl1))
self.assertTrue(np.all(feed_dict['z'] == pkl2)) self.assertTrue(np.all(feed_dict['z'] == pkl2))
def testInputParserQuoteAndWhitespace(self): def testInputParserPythonExpression(self):
x1 = np.ones([2, 10])
x2 = np.array([[1], [2], [3]])
x3 = np.mgrid[0:5, 0:5]
x4 = [[3], [4]]
input_expr_str = ('x1=np.ones([2,10]);x2=np.array([[1],[2],[3]]);'
'x3=np.mgrid[0:5,0:5];x4=[[3],[4]]')
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
'', input_expr_str)
self.assertTrue(np.all(feed_dict['x1'] == x1))
self.assertTrue(np.all(feed_dict['x2'] == x2))
self.assertTrue(np.all(feed_dict['x3'] == x3))
self.assertTrue(np.all(feed_dict['x4'] == x4))
def testInputParserBoth(self):
x0 = np.array([[1], [2]]) x0 = np.array([[1], [2]])
x1 = np.array(range(6)).reshape(2, 3) input_path = os.path.join(test.get_temp_dir(), 'input.npz')
input0_path = os.path.join(test.get_temp_dir(), 'input0.npy') np.savez(input_path, a=x0)
input1_path = os.path.join(test.get_temp_dir(), 'input1.npy') x1 = np.ones([2, 10])
np.save(input0_path, x0) input_str = 'x0=' + input_path + '[a]'
np.save(input1_path, x1) input_expr_str = 'x1=np.ones([2,10])'
input_str = '"x0=' + input0_path + '[x0] , x1 = ' + input1_path + '"' feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str) input_str, input_expr_str)
self.assertTrue(np.all(feed_dict['x0'] == x0)) self.assertTrue(np.all(feed_dict['x0'] == x0))
self.assertTrue(np.all(feed_dict['x1'] == x1)) self.assertTrue(np.all(feed_dict['x1'] == x1))
def testInputParserBothDuplicate(self):
x0 = np.array([[1], [2]])
input_path = os.path.join(test.get_temp_dir(), 'input.npz')
np.savez(input_path, a=x0)
x1 = np.ones([2, 10])
input_str = 'x0=' + input_path + '[a]'
input_expr_str = 'x0=np.ones([2,10])'
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
input_str, input_expr_str)
self.assertTrue(np.all(feed_dict['x0'] == x1))
def testInputParserErrorNoName(self): def testInputParserErrorNoName(self):
x0 = np.array([[1], [2]]) x0 = np.array([[1], [2]])
x1 = np.array(range(5)) x1 = np.array(range(5))
@ -284,7 +318,7 @@ Method name is: tensorflow/serving/predict"""
np.savez(input_path, a=x0, b=x1) np.savez(input_path, a=x0, b=x1)
input_str = 'x=' + input_path input_str = 'x=' + input_path
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
saved_model_cli.load_inputs_from_input_arg_string(input_str) saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
def testInputParserErrorWrongName(self): def testInputParserErrorWrongName(self):
x0 = np.array([[1], [2]]) x0 = np.array([[1], [2]])
@ -293,7 +327,7 @@ Method name is: tensorflow/serving/predict"""
np.savez(input_path, a=x0, b=x1) np.savez(input_path, a=x0, b=x1)
input_str = 'x=' + input_path + '[c]' input_str = 'x=' + input_path + '[c]'
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
saved_model_cli.load_inputs_from_input_arg_string(input_str) saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
def testRunCommandExistingOutdir(self): def testRunCommandExistingOutdir(self):
self.parser = saved_model_cli.create_parser() self.parser = saved_model_cli.create_parser()

View File

@ -994,7 +994,7 @@ class SVSummaryThread(coordinator.LooperThread):
summary_strs = self._sess.run(self._sv.summary_op) summary_strs = self._sess.run(self._sv.summary_op)
global_step = None global_step = None
if self._sv.summary_writer: if self._sv.summary_writer:
logging.info("Recording summary at step %d.", global_step) logging.info("Recording summary at step %s.", global_step)
self._sv.summary_writer.add_summary(summary_strs, global_step) self._sv.summary_writer.add_summary(summary_strs, global_step)
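
The %d-to-%s change matters because global_step is None on this code path; %s formats None harmlessly while %d raises. A one-line illustration (not part of this diff):

    global_step = None
    print("Recording summary at step %s." % global_step)  # "... at step None."
    try:
      print("Recording summary at step %d." % global_step)
    except TypeError as e:
      print(e)  # %d needs a number, so None cannot be formatted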

View File

@ -227,7 +227,7 @@ string ToString(CUresult result) {
// created by StreamExecutor (to ensure that the CUDA runtime didn't create a // created by StreamExecutor (to ensure that the CUDA runtime didn't create a
// context behind our backs). // context behind our backs).
CUcontext CurrentContext() { CUcontext CurrentContext() {
CUcontext current = CUDADriver::CurrentContextOrDie(); CUcontext current = CUDADriver::CurrentContextOrDie();
if (current != nullptr && !CreatedContexts::Has(current)) { if (current != nullptr && !CreatedContexts::Has(current)) {
LOG(FATAL) << "current context was not created by the StreamExecutor " LOG(FATAL) << "current context was not created by the StreamExecutor "
"cuda_driver API: " "cuda_driver API: "
@ -480,27 +480,56 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
CUdevice device, DeviceOptions device_options, CudaContext** context) { CUdevice device, DeviceOptions device_options, CudaContext** context) {
*context = nullptr; *context = nullptr;
CUcontext former_context = CurrentContext();
if (former_context != nullptr) {
LOG(WARNING) << "creating context when one is currently active; existing: "
<< former_context;
}
int flags = 0; int flags = 0;
if (!DeviceOptionsToContextFlags(device_options, &flags)) { if (!DeviceOptionsToContextFlags(device_options, &flags)) {
LOG(WARNING) << "could not convert all device options into context flags"; LOG(WARNING) << "could not convert all device options into context flags";
} }
CUresult res; CUresult res;
CUcontext former_context;
CUcontext new_context; CUcontext new_context;
{ {
// TODO(leary) Need to see if NVIDIA can expunge the leakiness in their // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
// context creation: see http://b/13248943 // context creation: see http://b/13248943
#if CUDA_VERSION >= 7000 #if CUDA_VERSION >= 7000
res = cuDevicePrimaryCtxSetFlags(device, flags); {
unsigned int former_primary_context_flags;
int former_primary_context_is_active;
CHECK_EQ(CUDA_SUCCESS,
cuDevicePrimaryCtxGetState(device, &former_primary_context_flags,
&former_primary_context_is_active));
if (former_primary_context_flags != flags) {
if (former_primary_context_is_active) {
LOG(ERROR)
<< "The primary context is active and has a different flag set ("
<< former_primary_context_flags << ") than the desired flag set ("
<< flags << ").";
} else {
CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags));
}
}
}
former_context = CUDADriver::CurrentContextOrDie();
res = cuDevicePrimaryCtxRetain(&new_context, device); res = cuDevicePrimaryCtxRetain(&new_context, device);
if (former_context != nullptr) {
if (former_context == new_context) {
VLOG(2) << "The primary context " << former_context
<< " exists before initializing the StreamExecutor.";
} else {
LOG(WARNING) << "A non-primary context " << former_context
<< " exists before initializing the StreamExecutor. We "
"haven't verified StreamExecutor works with that.";
}
}
#else #else
former_context = CurrentContext();
if (former_context != nullptr) {
LOG(WARNING)
<< "creating context when one is currently active; existing: "
<< former_context;
}
res = cuCtxCreate(&new_context, flags, device); res = cuCtxCreate(&new_context, flags, device);
#endif #endif
} }

View File

@ -334,8 +334,8 @@ int Main(int argc, char** argv) {
Flag("show_memory", &show_memory, "whether to list stats by memory used"), Flag("show_memory", &show_memory, "whether to list stats by memory used"),
Flag("memory_limit", &memory_limit, Flag("memory_limit", &memory_limit,
"how many items to show by memory used"), "how many items to show by memory used"),
Flag("show_type", &show_time, "whether to list stats by op type"), Flag("show_type", &show_type, "whether to list stats by op type"),
Flag("show_summary", &show_time, Flag("show_summary", &show_summary,
"whether to show a summary of the stats"), "whether to show a summary of the stats"),
Flag("show_flops", &show_flops, "whether to estimate the model's FLOPs"), Flag("show_flops", &show_flops, "whether to estimate the model's FLOPs"),
Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"),

View File

@ -9,7 +9,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "py_test")
py_library( py_binary(
name = "grpc_tensorflow_server", name = "grpc_tensorflow_server",
srcs = [ srcs = [
"grpc_tensorflow_server.py", "grpc_tensorflow_server.py",

View File

@ -36,6 +36,7 @@ from __future__ import print_function
import argparse import argparse
import sys import sys
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
@ -103,8 +104,11 @@ def main(unused_args):
raise ValueError("Invalid task_id: %d" % FLAGS.task_id) raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
server_def.task_index = FLAGS.task_id server_def.task_index = FLAGS.task_id
config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction))
# Create GRPC Server instance # Create GRPC Server instance
server = server_lib.Server(server_def) server = server_lib.Server(server_def, config=config)
# join() is blocking, unlike start() # join() is blocking, unlike start()
server.join() server.join()
@ -137,6 +141,11 @@ if __name__ == "__main__":
default=0, default=0,
help="Task index, e.g., 0" help="Task index, e.g., 0"
) )
parser.add_argument(
"--gpu_memory_fraction",
type=float,
default=1.0,
help="Fraction of GPU memory allocated",)
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",
type="bool", type="bool",
@ -145,5 +154,6 @@ if __name__ == "__main__":
default=False, default=False,
help="Verbose mode" help="Verbose mode"
) )
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed) app.run(main=main, argv=[sys.argv[0]] + unparsed)
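
The new --gpu_memory_fraction flag is forwarded to the server through a ConfigProto, as shown above. An equivalent standalone sketch (not part of this diff; the cluster spec is illustrative):

    import tensorflow as tf

    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))

    # A single-task cluster, just to show where the config is passed.
    cluster = tf.train.ClusterSpec({'local': ['localhost:2222']})
    server = tf.train.Server(cluster, job_name='local', task_index=0,
                             config=config)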

View File

@ -7,10 +7,7 @@ exports_files(["LICENSE-2.0.txt"])
native.cc_library( native.cc_library(
name = "linear_solver_glop", name = "linear_solver_glop",
deps = [ deps = [
"@ortools_archive//linear_solver:linear_solver_glop", "@ortools_archive//linear_solver:linear_solver_glop",
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )

View File

@ -1,4 +1,18 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -u # Check for undefined variables set -u # Check for undefined variables