Merge commit for internal changes
commit bbce813a58
@@ -388,6 +388,16 @@ tf_gen_op_wrappers_cc(
     visibility = ["//tensorflow:internal"],
 )
 
+tf_gen_op_wrappers_cc(
+    name = "functional_ops",
+    include_internal_ops = 1,
+    op_lib_names = [
+        "functional_ops",
+    ],
+    pkg = "//tensorflow/core",
+    visibility = ["//tensorflow:internal"],
+)
+
 tf_gen_op_wrappers_cc(
     name = "resource_variable_ops",
     include_internal_ops = 1,
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -339,6 +340,14 @@ class LiteralUtil {
                               const Layout& layout,
                               Literal* literal);
 
+  // Populates literal values by calling the generator function for every cell
+  // in the literal object.
+  template <typename NativeT>
+  static Status Populate(
+      Literal* literal,
+      const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
+          generator);
+
   // Creates a Literal of the given dimensions with all elements set to the
   // given value.
   template <typename NativeT>
@@ -992,6 +1001,43 @@ template <typename NativeT>
       literal);
 }
 
+template <typename NativeT>
+/* static */ Status LiteralUtil::Populate(
+    Literal* literal,
+    const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
+        generator) {
+  const Shape& shape = literal->shape();
+  int64 rank = ShapeUtil::Rank(shape);
+  TF_RET_CHECK(shape.element_type() ==
+               primitive_util::NativeToPrimitiveType<NativeT>());
+  tensorflow::protobuf::RepeatedField<NativeT>* data =
+      GetMutableRepeatedField<NativeT>(literal);
+  if (rank > 0) {
+    std::vector<int64> base(rank, 0);
+    std::vector<int64> step(rank, 1);
+    std::vector<int64> minor_scan_indexes(rank, 0);
+    int64 minor_dimension = shape.layout().minor_to_major()[0];
+    int64 minor_dimension_size =
+        ShapeUtil::GetDimension(shape, minor_dimension);
+
+    step[minor_dimension] = minor_dimension_size;
+    auto init_function = [&](const std::vector<int64>& indexes) {
+      int64 index = LinearIndex(*literal, indexes);
+      std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
+      for (int64 i = 0; i < minor_dimension_size; ++i) {
+        minor_scan_indexes[minor_dimension] = i;
+        data->Set(index + i, generator(minor_scan_indexes));
+      }
+      return true;
+    };
+    ShapeUtil::ForEachIndex(shape, base, AsInt64Slice(shape.dimensions()), step,
+                            init_function);
+  } else {
+    data->Set(0, generator({}));
+  }
+  return Status::OK();
+}
+
 template <typename NativeT>
 /* static */ void LiteralUtil::PopulateWithValue(
     NativeT value, tensorflow::gtl::ArraySlice<int64> dimensions,
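For illustration (not part of this commit): Populate visits only the first cell of each run along the minor-most dimension, because step[minor_dimension] equals that dimension's size; the inner loop then fills the whole run contiguously starting at the linear index. A minimal standalone sketch of the same traversal for a dense row-major 2-D buffer, with hypothetical names:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// Toy analogue of the Populate pattern above for a 2-D row-major buffer:
// the generator is called once per cell, and each row is filled as one
// contiguous run, mirroring the minor-dimension scan.
std::vector<float> Populate2D(
    int64_t rows, int64_t cols,
    const std::function<float(const std::vector<int64_t>&)>& generator) {
  std::vector<float> data(rows * cols);
  std::vector<int64_t> indexes(2, 0);
  for (int64_t r = 0; r < rows; ++r) {
    indexes[0] = r;
    int64_t base = r * cols;  // linear index of the first cell in the run
    for (int64_t c = 0; c < cols; ++c) {
      indexes[1] = c;
      data[base + c] = generator(indexes);
    }
  }
  return data;
}
```

For example, `Populate2D(2, 3, [](const std::vector<int64_t>& ix) { return float(ix[0] * 10 + ix[1]); })` yields {0, 1, 2, 10, 11, 12}.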
@@ -422,7 +422,7 @@ class ReferenceUtil {
   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
     AssertSameSize2D(array1, arrays...);
-    auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n1());
+    auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n2());
     for (int64 i = 0; i < array1.n1(); ++i) {
       for (int64 j = 0; j < array1.n2(); ++j) {
         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);
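The fix above corrects the result shape from n1() x n1() to n1() x n2(). For illustration (not part of this commit), a self-contained sketch of the same variadic elementwise pattern over a toy grid type (requires C++17 for the fold expression):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Minimal stand-in for Array2D: row-major storage with n1 rows, n2 columns.
template <typename T>
struct Grid {
  int64_t n1, n2;
  std::vector<T> data;
  Grid(int64_t r, int64_t c) : n1(r), n2(c), data(r * c) {}
  T& at(int64_t i, int64_t j) { return data[i * n2 + j]; }
  const T& at(int64_t i, int64_t j) const { return data[i * n2 + j]; }
};

// Applies f elementwise over one or more same-shaped grids; the result is
// n1 x n2 (the bug fixed above allocated n1 x n1 instead).
template <typename F, typename T, typename... Ts>
Grid<T> ApplyElementwise2D(F&& f, const Grid<T>& a, const Grid<Ts>&... rest) {
  (assert(rest.n1 == a.n1 && rest.n2 == a.n2), ...);
  Grid<T> result(a.n1, a.n2);
  for (int64_t i = 0; i < a.n1; ++i) {
    for (int64_t j = 0; j < a.n2; ++j) {
      result.at(i, j) = f(a.at(i, j), rest.at(i, j)...);
    }
  }
  return result;
}
```

Usage: `ApplyElementwise2D([](float x, float y) { return x + y; }, g1, g2)` adds two grids cell by cell.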
@@ -80,8 +80,6 @@ cc_library(
         ":hlo_query",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@@ -666,8 +664,8 @@ cc_library(
     ],
     deps = [
         ":buffer_liveness",
         ":heap_simulator",
         ":hlo",
         ":hlo_ordering",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -707,51 +705,38 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "heap_simulator",
-    srcs = [
-        "heap_simulator.cc",
-    ],
-    hdrs = [
-        "heap_simulator.h",
-    ],
-    deps = [
-        ":hlo",
-        ":liveness_util",
-        ":logical_buffer",
-        ":tuple_points_to_analysis",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/core:lib",
-    ],
-)
-
 cc_test(
     name = "heap_simulator_test",
     srcs = ["heap_simulator_test.cc"],
     deps = [
         ":heap_simulator",
         ":hlo",
         ":hlo_ordering",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/core:lib",
         "//tensorflow/core:test_main",
     ],
 )
 
+# The hlo_ordering library contains both hlo_ordering and heap_simulator because
+# they are mutually dependent.
 cc_library(
     name = "hlo_ordering",
     srcs = [
+        "heap_simulator.cc",
         "hlo_ordering.cc",
     ],
     hdrs = [
+        "heap_simulator.h",
        "hlo_ordering.h",
     ],
     deps = [
         ":call_graph",
-        ":heap_simulator",
         ":hlo",
         ":liveness_util",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -1436,6 +1421,7 @@ cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:test_main",
     ],
@@ -548,6 +548,8 @@ Status BufferAssigner::AssignBuffersForComputation(
     const FlatSet<const HloInstruction*>* hlos_to_allocate,
     const FlatSet<const LogicalBuffer*>& colocated_buffers,
     const FlatSet<BufferAllocation::Index>& colocated_allocations,
+    FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
+        buffers_to_assign_sequentially,
     BufferAssignment* assignment) {
   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
   // size.
@@ -578,9 +580,16 @@ Status BufferAssigner::AssignBuffersForComputation(
   // If there is a sequential instruction ordering, we'll delay assignment of
   // temp buffers until after the main assignment loop.
   const BufferLiveness& liveness = assignment->liveness();
-  const std::vector<const HloInstruction*>* sequential_order =
-      liveness.hlo_ordering().SequentialOrder(*computation);
-  FlatSet<const LogicalBuffer*> unassigned_temp_buffers;
+  const bool has_sequential_order =
+      liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
+  if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
+    // Every sequential computation must get an entry in the
+    // buffers_to_assign_sequentially map, even if we end up with an empty set
+    // of buffers. This ensures we can correctly determine whether to run
+    // whole-module heap simulation.
+    buffers_to_assign_sequentially->emplace(computation,
+                                            FlatSet<const LogicalBuffer*>());
+  }
 
   // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
   // first for simplicity. This means any previously created BufferAllocation is
@@ -599,7 +608,7 @@ Status BufferAssigner::AssignBuffersForComputation(
   // important reuse case where an elementwise instruction reuses one of its
   // operand's buffer. This improves locality.
   std::sort(sorted_buffers.begin(), sorted_buffers.end(),
-            [this, sequential_order, &liveness, &post_order_position](
+            [this, has_sequential_order, &liveness, &post_order_position](
                 const LogicalBuffer* a, const LogicalBuffer* b) {
              // Primary sort is by decreasing buffer size.
              const int64 a_size = buffer_size_(*a);
@@ -609,7 +618,7 @@ Status BufferAssigner::AssignBuffersForComputation(
              }
              // Otherwise live out buffers come before others, if the
              // instructions are sequentially ordered.
-             if (sequential_order != nullptr) {
+             if (has_sequential_order) {
                const bool a_live_out = liveness.MaybeLiveOut(*a);
                const bool b_live_out = liveness.MaybeLiveOut(*b);
                if (a_live_out != b_live_out) {
@@ -746,7 +755,7 @@ Status BufferAssigner::AssignBuffersForComputation(
      }
    }
 
-    if (!assignment->HasAllocation(*buffer) && sequential_order != nullptr &&
+    if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
        !liveness.MaybeLiveOut(*buffer)) {
      // There is a sequential instruction ordering, so we delay assignment of
      // temp buffers until after the loop. We do this right before we decide to
@@ -758,7 +767,7 @@ Status BufferAssigner::AssignBuffersForComputation(
      // for the definition of temp buffers.
      CHECK(!is_entry_parameter) << *buffer;
      CHECK(!is_thread_local) << *buffer;
-      unassigned_temp_buffers.insert(buffer);
+      (*buffers_to_assign_sequentially)[computation].insert(buffer);
      VLOG(3) << "Delaying assignment of temp buffer: " << *buffer;
      continue;
    }
@@ -772,27 +781,68 @@ Status BufferAssigner::AssignBuffersForComputation(
     }
   }
 
-  if (!unassigned_temp_buffers.empty()) {
-    TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
-        *sequential_order, unassigned_temp_buffers, *computation, assignment));
-  }
   return Status::OK();
 }
 
 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
-    const std::vector<const HloInstruction*>& sequence,
-    const FlatSet<const LogicalBuffer*>& buffers_to_assign,
-    const HloComputation& computation, BufferAssignment* assignment) {
+    const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>&
+        buffers_to_assign_sequentially,
+    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
   // Run the sequence of instructions through the heap simulator. The heuristic
   // that seems to give the best results is lazy-best-fit, with all runs of
   // alloc / free calls sorted in decreasing size order.
+  const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+  if (run_whole_module_heap_simulation) {
+    // Run the heap simulation over the whole module. This reduces memory usage,
+    // since buffers for kCall and kWhile sub-computations are only live for the
+    // duration of their calling instructions.
+    VLOG(1) << "Running whole-module heap simulation";
+    SequentialHloOrdering::HloModuleSequence module_sequence;
+    FlatSet<const LogicalBuffer*> all_buffers_to_assign;
+    for (const auto& pair : buffers_to_assign_sequentially) {
+      const HloComputation* computation = pair.first;
+      const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+      const std::vector<const HloInstruction*>* instruction_sequence =
+          hlo_ordering.SequentialOrder(*computation);
+      CHECK(instruction_sequence != nullptr) << computation->name();
+      module_sequence[computation] = *instruction_sequence;
+      all_buffers_to_assign.insert(buffers_to_assign.begin(),
+                                   buffers_to_assign.end());
+    }
     TF_ASSIGN_OR_RETURN(
-        HeapSimulator::Result result,
+        const HeapSimulator::Result result,
         HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
                                MakeUnique<LazyBestFitHeap>(alignment_)),
-                           sequence, computation,
+                           assignment->module(), module_sequence,
                            assignment->points_to_analysis(), buffer_size_,
+                           &all_buffers_to_assign));
+    AssignBuffersFromHeapSimulator(result, assignment);
+  } else {
+    // Run the heap-simulation on a per-computation basis. Buffers for
+    // sub-computations are assigned disjoint BufferAllocations, assuming the
+    // worst-case that they may all be live concurrently.
+    VLOG(1) << "Running per-computation heap simulation";
+    for (const auto& pair : buffers_to_assign_sequentially) {
+      const HloComputation* computation = pair.first;
+      const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+      const std::vector<const HloInstruction*>* instruction_sequence =
+          hlo_ordering.SequentialOrder(*computation);
+      CHECK(instruction_sequence != nullptr) << computation->name();
+      TF_ASSIGN_OR_RETURN(
+          const HeapSimulator::Result result,
+          HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
+                                 MakeUnique<LazyBestFitHeap>(alignment_)),
+                             *computation, *instruction_sequence,
+                             assignment->points_to_analysis(), buffer_size_,
+                             &buffers_to_assign));
+      AssignBuffersFromHeapSimulator(result, assignment);
+    }
+  }
+  return Status::OK();
+}
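For illustration (not part of this commit): the whole-module branch lets buffers of kCall and kWhile sub-computations share space, because their lifetimes nest inside the calling instruction, while the per-computation branch must assume sub-computations may all be live concurrently. A toy peak-heap calculation with made-up sizes:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy trace event: positive = Alloc of that many bytes, negative = Free.
int64_t PeakHeap(const std::vector<int64_t>& trace) {
  int64_t live = 0, peak = 0;
  for (int64_t e : trace) {
    live += e;
    peak = std::max(peak, live);
  }
  return peak;
}

int main() {
  // Caller allocates 8 bytes, calls sub-computation A (needs 16), and only
  // after A's buffers are freed calls sub-computation B (needs 24).
  std::vector<int64_t> whole_module = {8, 16, -16, 24, -24, -8};
  std::cout << PeakHeap(whole_module) << "\n";  // 32: A and B never coexist
  // Per-computation accounting reserves disjoint space for A and B:
  std::cout << 8 + 16 + 24 << "\n";  // 48: worst-case sum
  return 0;
}
```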
 
+void BufferAssigner::AssignBuffersFromHeapSimulator(
+    const HeapSimulator::Result& result, BufferAssignment* assignment) {
   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
     assignment->stats_.preallocated_temp_fragmentation_bytes =
         result.fragmentation_size;
@@ -801,8 +851,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
         result.fragmentation_size;
   }
 
-  // Use the results of the heap simulator to create one allocation per
-  // computation, with LogicalBuffers packed to specific offsets.
   BufferAllocation* allocation = assignment->NewEmptyAllocation(
       result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true);
   for (const auto& buffer_chunk : result.chunk_map) {
@@ -810,7 +858,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
     const HeapSimulator::Chunk& chunk = buffer_chunk.second;
     assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
   }
-  return Status::OK();
 }
 
 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
@@ -1108,8 +1155,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
                       BufferLiveness::Run(module, std::move(hlo_ordering)));
 
-  std::vector<const HloComputation*> thread_local_computations;
-  std::vector<const HloComputation*> global_computations;
   VLOG(1) << "Assigning buffers to module " << module->name();
   if (hlos_to_allocate != nullptr) {
     VLOG(3) << "LogicalBuffer assignment restricted to hlos: ";
@@ -1121,9 +1166,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   XLA_VLOG_LINES(3, liveness->ToString());
   XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
 
-  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
-      module, &thread_local_computations, &global_computations));
-
   // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to
   // AssignBuffersForComputation for fast membership testing.
   std::unique_ptr<FlatSet<const HloInstruction*>> hlo_set;
@@ -1148,16 +1190,38 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
                             &colocated_buffers, &colocated_allocations);
 
+  std::vector<const HloComputation*> thread_local_computations;
+  std::vector<const HloComputation*> global_computations;
+  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
+      module, &thread_local_computations, &global_computations));
+
+  // First assign buffers for global computations. Temporary buffers for
+  // sequential computations are collected in 'buffers_to_assign_sequentially'.
+  FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
+      buffers_to_assign_sequentially;
   for (auto* computation : global_computations) {
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
         computation, /*is_thread_local=*/false, hlo_set.get(),
-        colocated_buffers, colocated_allocations, assignment.get()));
+        colocated_buffers, colocated_allocations,
+        &buffers_to_assign_sequentially, assignment.get()));
   }
+  // Assign buffers with sequential ordering, if any. If all global computations
+  // are sequential, we can run heap simulation on the whole module, which
+  // reduces memory usage.
+  const bool run_whole_module_heap_simulation =
+      buffers_to_assign_sequentially.size() == global_computations.size();
+  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
+      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
+      assignment.get()));
 
   // Now assign buffers for thread-local computations. All LogicalBuffers get
   // their own BufferAllocation.
   for (auto* computation : thread_local_computations) {
     TF_RET_CHECK(computation != module->entry_computation());
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
         computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers,
-        colocated_allocations, assignment.get()));
+        colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr,
+        assignment.get()));
   }
 
   // Mark all buffers which may be live out of the entry computation as
@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -354,6 +355,9 @@ class BufferAssignment {
   void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer,
                      int64 offset, int64 size);
 
+  // Returns the HloModule used to construct this assignment.
+  const HloModule& module() { return *module_; }
+
   // Returns the BufferLiveness object used to construct this assignment.
   const BufferLiveness& liveness() { return *liveness_; }
 
@@ -427,14 +431,27 @@ class BufferAssigner {
       const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
       const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
           colocated_allocations,
+      tensorflow::gtl::FlatMap<const HloComputation*,
+                               tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
+          buffers_to_assign_sequentially,
       BufferAssignment* assignment);
 
-  // Assigns 'buffers_to_assign' assuming the HLO instructions will be executed
-  // in the given 'sequential_order'.
+  // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
+  // the HLO instructions will be executed in the sequential order given by
+  // assignment->liveness().hlo_ordering().SequentialOrder. If
+  // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
+  // assuming all global computations are sequentially ordered.
   Status AssignBuffersWithSequentialOrdering(
-      const std::vector<const HloInstruction*>& sequential_order,
-      const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers_to_assign,
-      const HloComputation& computation, BufferAssignment* assignment);
+      const tensorflow::gtl::FlatMap<
+          const HloComputation*,
+          tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
+          buffers_to_assign_sequentially,
+      bool run_whole_module_heap_simulation, BufferAssignment* assignment);
+
+  // Uses the results of the heap simulator to create a single allocation, with
+  // LogicalBuffers packed to specific offsets.
+  void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
+                                      BufferAssignment* assignment);
 
   // Tries to assign the given instruction to the given buffer. Returns if the
   // assignment was successful.
@@ -477,8 +494,6 @@ class BufferAssigner {
       const HloComputation& computation, const BufferLiveness& buffer_liveness,
       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
 
-  const HloModule* module_;
-
   // Function which returns the buffer size for a given logical buffer (shape).
   LogicalBuffer::SizeFunction buffer_size_;
 
@@ -24,6 +24,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                       int64 operand_index) {
   HloInstruction* producer = consumer->mutable_operand(operand_index);
 
+  // Output fusion is not currently supported on CPUs.
+  if (producer->opcode() == HloOpcode::kFusion) {
+    return false;
+  }
+
   // Condition for consumer: must be elementwise or a fusion op
   // (which necessarily only contains elementwise operations)
   if (!(consumer->opcode() == HloOpcode::kFusion ||
@@ -46,6 +46,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                       int64 operand_index) {
   HloInstruction* producer = consumer->mutable_operand(operand_index);
 
+  // Output fusion is not currently supported on GPUs.
+  if (producer->opcode() == HloOpcode::kFusion) {
+    return false;
+  }
+
   // RNG operations are not currently parallel-friendly on GPU.
   if (producer->opcode() == HloOpcode::kRng) {
     return false;
@@ -53,12 +53,44 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
 
 /*static*/
 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm,
-    const std::vector<const HloInstruction*>& instruction_sequence,
-    const HloComputation& computation,
+    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+    const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_fn,
    const FlatSet<const LogicalBuffer*>* buffers_to_assign) {
   HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
+  const HloComputation* entry_computation = module.entry_computation();
+  const std::vector<const HloInstruction*>& instruction_sequence =
+      FindOrDie(module_sequence, entry_computation);
+  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
+                                         instruction_sequence,
+                                         points_to_analysis, &module_sequence));
+  return heap.Finish();
+}
+
+/*static*/
+StatusOr<HeapSimulator::Result> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
+    const std::vector<const HloInstruction*>& instruction_sequence,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_fn,
+    const FlatSet<const LogicalBuffer*>* buffers_to_assign) {
+  HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
+  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
+                                         points_to_analysis,
+                                         /*module_sequence=*/nullptr));
+  return heap.Finish();
+}
+
+// Runs a heap simulation for the given 'computation', assuming the given
+// 'instruction_sequence'. If 'module_sequence' is non-null, it is used to find
+// kCall and kWhile sub-computations, and the heap simulation for those
+// sub-computations will be run recursively.
+Status HeapSimulator::RunComputation(
+    const HloComputation& computation,
+    const std::vector<const HloInstruction*>& instruction_sequence,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const SequentialHloOrdering::HloModuleSequence* module_sequence) {
   // The goal here is to minimize memory usage, assuming the given sequential
   // ordering of instructions. The strategy is to walk through the instruction
   // sequence, calling Alloc and Free on the underlying heap algorithm. The
@@ -67,7 +99,6 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
   // 'live_buffers' tracks the liveness of each buffer that we assign, by
   // associating it with a set of HloInstructions that need to be visited. When
   // the set becomes empty, the buffer is no longer used, and can be freed.
-  HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
   FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;
 
   const HloInstruction* root = computation.root_instruction();
@@ -90,7 +121,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     // lifetime of buffers that aren't already connected by a data dependency.
     std::vector<const LogicalBuffer*> dead_buffers_to_free;
     for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
-      if (heap.IgnoreBuffer(buffer)) {
+      if (IgnoreBuffer(buffer)) {
        continue;
      }
      for (const BufferAlias& alias :
@@ -127,7 +158,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::vector<const LogicalBuffer*> operand_buffers_to_free;
     for (const LogicalBuffer* operand_buffer :
          UniqueOperandSourceBuffers(instruction, points_to_analysis)) {
-      if (heap.IgnoreBuffer(operand_buffer)) {
+      if (IgnoreBuffer(operand_buffer)) {
        continue;
      }
      live_buffers[operand_buffer].erase(instruction);
@@ -142,10 +173,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     // happen before dead or operand buffers are freed; the instruction reads
     // the operand buffers to produce its output.
     //
-    // INVARIANT: Either heap.Alloc or heap.ShareBuffer will be called for each
-    // buffer that we should assign.
+    // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
+    // that we should assign.
     for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
-      if (heap.IgnoreBuffer(buffer)) {
+      if (IgnoreBuffer(buffer)) {
        continue;
      }
 
@@ -159,24 +190,50 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
           CanShareOperandBufferWithUser(
               operand_buffer->instruction(), operand_buffer->index(),
               buffer->instruction(), buffer->index(), points_to_analysis)) {
-        heap.ShareBuffer(buffer, operand_buffer);
+        ShareBuffer(buffer, operand_buffer);
         shared = true;
         break;
       }
     }
 
     if (!shared) {
-      heap.Alloc(buffer);
+      Alloc(buffer);
     }
   }
 
+  // If the whole module is sequential, we can save memory by running the
+  // heap-simulation for sub-computations inline. E.g. the buffers for the
+  // condition and body of a kWhile instruction are only live for the duration
+  // of the instruction itself.
+  //
+  // The order that the sub-computations are simulated does not affect
+  // correctness; since the whole module is sequential, we know that the
+  // sub-computations will never be run concurrently.
+  if (module_sequence != nullptr) {
+    if (instruction->opcode() == HloOpcode::kCall ||
+        instruction->opcode() == HloOpcode::kWhile) {
+      for (const HloComputation* called_computation :
+           instruction->called_computations()) {
+        const std::vector<const HloInstruction*>& called_sequence =
+            FindOrDie(*module_sequence, called_computation);
+        TF_RETURN_IF_ERROR(RunComputation(*called_computation,
+                                          called_sequence, points_to_analysis,
+                                          module_sequence));
+      }
+    }
+
+    // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are
+    // assigned "thread-local" allocations, meaning their buffers are not
+    // allocated up-front at the beginning of the computation.
+  }
+
   // Free buffers that are no longer live. This is the earliest point that we
   // can de-allocate; right after the last use of the buffer.
   for (const LogicalBuffer* buffer : dead_buffers_to_free) {
-    heap.Free(buffer);
+    Free(buffer);
   }
   for (const LogicalBuffer* buffer : operand_buffers_to_free) {
-    heap.Free(buffer);
+    Free(buffer);
   }
 }
 
@@ -187,10 +244,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
     CHECK_EQ(pending.size(), 1) << *buffer;
     CHECK(*pending.begin() == nullptr) << *buffer;
-    heap.Free(buffer);
+    Free(buffer);
   }
 
-  return heap.Finish();
+  return Status::OK();
 }
 
 HeapSimulator::HeapSimulator(
@@ -309,6 +366,11 @@ HeapSimulator::Result HeapSimulator::Finish() {
       result.chunk_map.emplace(buffer, chunk);
     }
   }
+  // If we were told to assign specific buffers, make sure we've assigned
+  // exactly that many buffers.
+  if (buffers_to_assign_ != nullptr) {
+    CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size());
+  }
 }
 
 // Fragmentation is the difference between the actual and ideal sizes.
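For illustration (not part of this commit), the fragmentation statistic mentioned above is just the actual heap size minus the fragmentation-free lower bound; made-up numbers:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Suppose the packing algorithm ended with a 96-byte heap, while the
  // no-fragmentation lower bound (peak sum of live buffer sizes) was 80.
  const int64_t heap_size = 96;
  const int64_t ideal_size = 80;
  const int64_t fragmentation_size = heap_size - ideal_size;
  std::cout << fragmentation_size << "\n";  // 16 bytes lost to fragmentation
  return 0;
}
```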
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -63,17 +64,32 @@ class HeapSimulator {
   };
 
   // Run the heap simulation with the given algorithm, assuming the given
-  // sequential ordering of instructions. The 'instruction_sequence' must
-  // contain a topologically-consistent total ordering of all instructions in
-  // the computation. The result is invalid if instructions are not run in
-  // exactly this sequence.
+  // module_sequence, which must contain a topologically-consistent total
+  // ordering of all instructions within each computation. The result is invalid
+  // if instructions are not run in exactly this sequence.
+  //
+  // Running heap simulation on the whole module tends to save memory, compared
+  // to running on a per-computation basis, since we can re-use buffer space for
+  // called sub-computations.
+  //
   // If 'buffers_to_assign' is provided, only those buffers are assigned
   // offsets, otherwise all buffers defined by the instructions are assigned.
   static StatusOr<Result> Run(
+      std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+      const SequentialHloOrdering::HloModuleSequence& module_sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const LogicalBuffer::SizeFunction& size_fn,
+      const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign =
+          nullptr);
+
+  // Same as above, but runs on a single computation. The 'instruction_sequence'
+  // must contain a topologically-consistent total ordering of all instructions
+  // in the computation. The result is invalid if instructions are not run in
+  // exactly this sequence.
+  static StatusOr<Result> Run(
       std::unique_ptr<HeapAlgorithm> algorithm,
-      const std::vector<const HloInstruction*>& instruction_sequence,
       const HloComputation& computation,
+      const std::vector<const HloInstruction*>& instruction_sequence,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_fn,
       const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign =
@@ -86,6 +102,12 @@ class HeapSimulator {
       const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign);
   ~HeapSimulator();
 
+  Status RunComputation(
+      const HloComputation& computation,
+      const std::vector<const HloInstruction*>& instruction_sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const SequentialHloOrdering::HloModuleSequence* module_sequence);
+
   bool IgnoreBuffer(const LogicalBuffer* buffer) const;
   void Alloc(const LogicalBuffer* buffer);
   void Free(const LogicalBuffer* buffer);
@@ -19,13 +19,16 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace xla {
 namespace {
@@ -69,6 +72,7 @@ class HeapCallRecorder : public HeapAlgorithm {
 // sequence against an expected sequence.
 class HeapSimulatorTracker {
  public:
+  // Constructor for testing a single entry computation.
   HeapSimulatorTracker(
       const string& name, std::unique_ptr<HloComputation> computation,
       const std::vector<const HloInstruction*>& instruction_sequence) {
@@ -83,12 +87,48 @@ class HeapSimulatorTracker {
     auto zero_size = [](const LogicalBuffer& buffer) { return 0; };
     auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
         MakeUnique<HeapCallRecorder>(&actual_calls_));
-    result_ = HeapSimulator::Run(std::move(algorithm), instruction_sequence,
-                                 *module_->entry_computation(),
-                                 *points_to_analysis_, zero_size)
+    result_ = HeapSimulator::Run(
+                  std::move(algorithm), *module_->entry_computation(),
+                  instruction_sequence, *points_to_analysis_, zero_size)
                   .ConsumeValueOrDie();
   }
 
+  explicit HeapSimulatorTracker(const string& name) {
+    module_ = MakeUnique<HloModule>(name);
+  }
+
+  // Similar to the single entry computation constructor above, but runs the
+  // simulation over the entire module.
+  void RunWholeModule(
+      const std::vector<const HloInstruction*>& full_module_sequence) {
+    points_to_analysis_ =
+        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+
+    // Construct the module sequence grouped by computation.
+    SequentialHloOrdering::HloModuleSequence module_sequence;
+    tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
+    for (int i = 0; i < full_module_sequence.size(); ++i) {
+      const HloInstruction* instruction = full_module_sequence[i];
+      module_sequence[instruction->parent()].push_back(instruction);
+      reverse_position[instruction] = full_module_sequence.size() - i;
+    }
+
+    // Hack the size_fn so that it returns a decreasing value as we step through
+    // the sequence. This lets us ensure the Alloc calls are in the sequence
+    // order. The Free calls are sorted by LogicalBuffer.id, which is at least
+    // deterministic.
+    auto size_fn = [&reverse_position](const LogicalBuffer& buffer) {
+      return reverse_position[buffer.instruction()];
+    };
+    auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
+        MakeUnique<HeapCallRecorder>(&actual_calls_));
+    result_ = HeapSimulator::Run(std::move(algorithm), *module_,
+                                 module_sequence, *points_to_analysis_, size_fn)
                  .ConsumeValueOrDie();
+  }
 
   HloModule* module() { return module_.get(); }
 
   // Returns the buffer defined at the given instruction and index.
   const LogicalBuffer* BufferAt(const HloInstruction* instruction,
                                 const ShapeIndex& index) const {
@@ -358,6 +398,86 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
   });
 }
 
+TEST_F(HeapSimulatorTest, WholeModule) {
+  HeapSimulatorTracker tracker(TestName());
+
+  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+  HloInstruction* cond_data = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+                                   HloOpcode::kLt, cond_iter, cond_data));
+  HloComputation* cond_computation =
+      tracker.module()->AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloComputation* body_computation =
+      tracker.module()->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, cond_computation, body_computation, param));
+  tracker.module()->AddEntryComputation(builder.Build());
+
+  tracker.RunWholeModule(
+      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
+  tracker.ExpectCallSequence({
+      // The entry computation param and while_op are allocated first.
+      {kAlloc, tracker.BufferAt(param, {})},
+      {kAlloc, tracker.BufferAt(param, {0})},
+      {kAlloc, tracker.BufferAt(param, {1})},
+      {kAlloc, tracker.BufferAt(while_op, {})},
+      {kAlloc, tracker.BufferAt(while_op, {0})},
+      {kAlloc, tracker.BufferAt(while_op, {1})},
+
+      // Now the while body param is allocated and freed.
+      {kAlloc, tracker.BufferAt(body_param, {})},
+      {kAlloc, tracker.BufferAt(body_param, {0})},
+      {kAlloc, tracker.BufferAt(body_param, {1})},
+      {kFree, tracker.BufferAt(body_param, {})},
+      {kFree, tracker.BufferAt(body_param, {0})},
+      {kFree, tracker.BufferAt(body_param, {1})},
+
+      // Now the while cond param is allocated. The GTE instructions just alias
+      // the param elements, so the param tuple can immediately be freed.
+      {kAlloc, tracker.BufferAt(cond_param, {})},
+      {kAlloc, tracker.BufferAt(cond_param, {0})},
+      {kAlloc, tracker.BufferAt(cond_param, {1})},
+      {kFree, tracker.BufferAt(cond_param, {})},
+
+      // Now the final cond less-than buffer is allocated.
+      {kAlloc, tracker.BufferAt(cond_lt, {})},
+
+      // The order of the remaining Free calls is based on the LogicalBuffer.id,
+      // which is deterministic, but not obvious.
+      {kFree, tracker.BufferAt(param, {})},
+      {kFree, tracker.BufferAt(param, {0})},
+      {kFree, tracker.BufferAt(param, {1})},
+
+      {kFree, tracker.BufferAt(while_op, {})},
+      {kFree, tracker.BufferAt(while_op, {0})},
+      {kFree, tracker.BufferAt(while_op, {1})},
+
+      {kFree, tracker.BufferAt(cond_param, {0})},
+      {kFree, tracker.BufferAt(cond_param, {1})},
+
+      {kFree, tracker.BufferAt(cond_lt, {})},
+
+      {kFinish, nullptr},
+  });
+}
+
 // Base class for heap algorithm tests.
 class HeapAlgorithmTestBase : public ::testing::Test {
  protected:
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace op = xla::testing::opcode_matchers;
@@ -49,8 +50,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
 
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
@@ -70,8 +72,9 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
 
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
@@ -91,8 +94,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
 
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(
@@ -131,11 +135,12 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
     Shape shape = ShapeUtil::MakeShape(F32, dimensions);
     builder.AddInstruction(HloInstruction::CreateConcatenate(
         shape, operands, test_config.concat_dimension));
-    HloModule module(TestName());
-    auto computation = module.AddEntryComputation(builder.Build());
+    auto module = MakeUnique<HloModule>(TestName());
+    auto computation = module->AddEntryComputation(builder.Build());
 
-    HloConstantFolding simplifier;
-    ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+    HloConstantFolding const_folder;
+    TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+    EXPECT_TRUE(result);
 
     HloInstruction* root = computation->root_instruction();
     EXPECT_THAT(root, op::Constant());
@@ -148,22 +153,61 @@ TEST_F(HloConstantFoldingTest, Slice) {
   const int64 dimensions[] = {11, 8, 7, 5, 9};
   const int64 slice_start[] = {4, 2, 3, 1, 5};
   const int64 slice_limits[] = {10, 8, 6, 5, 9};
-  auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
-  HloInstruction* lit_insn = builder.AddInstruction(
+  TF_ASSIGN_OR_ASSERT_OK(auto literal,
+                         LiteralTestUtil::CreateRandomLiteral<F32>(
+                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+  HloInstruction* literal_instruction = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
   Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
-  builder.AddInstruction(
-      HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  builder.AddInstruction(HloInstruction::CreateSlice(
+      shape, literal_instruction, slice_start, slice_limits));
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
 
-  HloConstantFolding simplifier;
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Constant());
   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
 }
 
+TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
+  HloComputation::Builder builder(TestName());
+  const int64 dimensions[] = {11, 8, 7, 5, 9};
+  TF_ASSIGN_OR_ASSERT_OK(auto literal,
+                         LiteralTestUtil::CreateRandomLiteral<F32>(
+                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+  auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+  HloInstruction* literal_instruction = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+  const int64 permutation[] = {1, 2, 0, 4, 3};
+  builder.AddInstruction(
+      HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  HloConstantFolding const_folder;
+  TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
+
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_THAT(root, op::Constant());
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+  bool matched = true;
+  LiteralUtil::EachCell<NativeT>(
+      root->literal(),
+      [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+        std::vector<int64> rindexes = Permute(permutation, indices);
+        matched = matched && (value == LiteralUtil::Get<NativeT>(*literal_clone,
+                                                                 rindexes));
+      });
+  EXPECT_TRUE(matched);
+}
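For illustration (not part of this commit): the EachCell check above verifies that each value in the transposed constant equals the source literal at the permuted index. A self-contained 2-D version of that invariant (the permutation {1, 0} is its own inverse, so the permute convention does not matter here):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // 2 x 3 input, row-major; transpose into a 3 x 2 output.
  const int64_t rows = 2, cols = 3;
  std::vector<int> in = {1, 2, 3, 4, 5, 6};
  std::vector<int> out(cols * rows);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      out[c * rows + r] = in[r * cols + c];
    }
  }
  // The test's invariant: the transposed value at index (c, r) equals the
  // source value at the permuted index (r, c).
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      assert(out[c * rows + r] == in[r * cols + c]);
    }
  }
  return 0;
}
```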
 } // namespace
 } // namespace xla
@@ -1570,7 +1570,9 @@ string HloInstruction::ToCategory() const {
       return "non-elementwise fusion";
     }
     case FusionKind::kInput:
-      return "reduce fusion";
+      return "input fusion";
+    case FusionKind::kOutput:
+      return "output fusion";
     case FusionKind::kTransposeDot:
       return "dot fusion";
     case FusionKind::kConvBackwardFilter:
@@ -1618,7 +1620,6 @@ bool HloInstruction::IsFusable() const {
 
   // Some kinds of instructions don't make sense to fuse.
   switch (opcode_) {
-    case HloOpcode::kFusion:
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
     case HloOpcode::kParameter:
@@ -2186,6 +2187,8 @@ string ToString(HloInstruction::FusionKind kind) {
       return "kLoop";
     case HloInstruction::FusionKind::kInput:
       return "kInput";
+    case HloInstruction::FusionKind::kOutput:
+      return "kOutput";
     case HloInstruction::FusionKind::kTransposeDot:
       return "kTransposeDot";
     case HloInstruction::FusionKind::kConvBackwardFilter:
@@ -54,7 +54,8 @@ class HloInstruction {
  public:
   enum class FusionKind {
     kLoop,                // Fused into a loop.
-    kInput,               // Fused into a reduction kernel.
+    kInput,               // Op's input is fused into the op itself.
+    kOutput,              // Op's output is fused into the op itself.
     kTransposeDot,        // Fused into a dot with transposed operands.
     kConvBackwardFilter,  // Fused into a backward filter convolution.
     kConvBackwardInput,   // Fused into a backward input convolution.
@@ -221,23 +221,6 @@ string SequentialHloOrdering::ToString() const {
   return tensorflow::str_util::Join(pieces, "\n");
 }
 
-namespace {
-StatusOr<int64> MinimumMemoryForSequence(
-    const HloComputation& computation,
-    const std::vector<const HloInstruction*>& sequence,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function) {
-  // The absolute minimum memory required for a given sequence of instructions
-  // is determined by the sequence of Alloc and Free calls on a simulated heap,
-  // ignoring fragmentation.
-  TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), sequence,
-                         computation, points_to_analysis, size_function));
-  return result.heap_size;
-}
-}  // namespace
-
 StatusOr<int64> MinimumMemoryForSequence(
     const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const LogicalBuffer::SizeFunction& size_function) {
@@ -249,17 +232,16 @@ StatusOr<int64> MinimumMemoryForSequence(
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(module));
 
-  int64 total_memory = 0;
-  for (const auto& pair : module_sequence) {
-    const HloComputation* computation = pair.first;
-    const std::vector<const HloInstruction*>& sequence = pair.second;
+  // The absolute minimum memory required for a given sequence of instructions
+  // is determined by the sequence of Alloc and Free calls on a simulated heap,
+  // ignoring fragmentation. We run the heap simulation on the whole module,
+  // rather than summing each computation, since it gives us a better lower
+  // bound, by minimizing the liveness of sub-computations.
   TF_ASSIGN_OR_RETURN(
-      const int64 memory,
-      MinimumMemoryForSequence(*computation, sequence, *points_to_analysis,
-                               size_function));
-    total_memory += memory;
-  }
-  return total_memory;
+      HeapSimulator::Result result,
+      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
+                         module_sequence, *points_to_analysis, size_function));
+  return result.heap_size;
 }
 
 namespace {
@@ -516,6 +498,18 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
   return sequence;
 }
 
+StatusOr<int64> MinimumMemoryForComputation(
+    const HloComputation& computation,
+    const std::vector<const HloInstruction*>& sequence,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function) {
+  TF_ASSIGN_OR_RETURN(
+      HeapSimulator::Result result,
+      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
+                         sequence, points_to_analysis, size_function));
+  return result.heap_size;
+}
+
 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
@@ -523,13 +517,17 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
   // We try both a list-scheduler based ordering and a DFS based ordering, and
   // choose whichever returns a lower min-memory, not accounting for
   // fragmentation.
+  //
+  // Note that this is just a heuristic. One obvious inaccuracy is that the
+  // memory required for sub-computations might be different when considered
+  // within the caller's context. But it's good enough for now.
   TF_ASSIGN_OR_RETURN(
       std::vector<const HloInstruction*> list_sequence,
       ListScheduler::Run(computation, points_to_analysis, size_function));
   TF_ASSIGN_OR_RETURN(
       const int64 list_memory,
-      MinimumMemoryForSequence(computation, list_sequence, points_to_analysis,
-                               size_function));
+      MinimumMemoryForComputation(computation, list_sequence,
+                                  points_to_analysis, size_function));
   VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes";
 
   TF_ASSIGN_OR_RETURN(
@@ -537,7 +535,7 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
       RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
   TF_ASSIGN_OR_RETURN(
       const int64 dfs_memory,
-      MinimumMemoryForSequence(computation, dfs_sequence, points_to_analysis,
+      MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
                                size_function));
   VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes";
 
@@ -155,6 +155,65 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
   EXPECT_FALSE(ordering.ExecutesBefore(y, c));
 }
 
+class MinimumMemoryForSequenceTest : public HloTestBase {};
+
+TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
+  HloModule module(TestName());
+  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+  HloInstruction* cond_data = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+                                   HloOpcode::kLt, cond_iter, cond_data));
+  HloComputation* cond_computation =
+      module.AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloComputation* body_computation =
+      module.AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
+  HloInstruction* iter = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
+  HloInstruction* data = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
+  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
+  // While: 8 bytes (4 bytes per element), TOTAL=32
+  // Both cond and body use a max of 24 bytes, TOTAL=56
+  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, cond_computation, body_computation, tuple));
+  HloComputation* entry_computation =
+      module.AddEntryComputation(builder.Build());
+
+  auto size_fn = [](const LogicalBuffer& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
+  };
+
+  SequentialHloOrdering::HloModuleSequence module_sequence;
+  module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
+                                       cond_lt};
+  module_sequence[body_computation] = {body_param};
+  module_sequence[entry_computation] = {iter, data, tuple, while_op};
+  EXPECT_EQ(56,
+            MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
+}
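For illustration (not part of this commit), the expected 56 bytes follow directly from the size comments in the test above:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t entry_params = 4 + 4;  // two F32 scalar parameters
  const int64_t tuple = 8 + 8;         // two 8-byte tuple pointers
  const int64_t while_elems = 4 + 4;   // while result elements
  const int64_t cond_peak = 24, body_peak = 24;
  const int64_t total =
      entry_params + tuple + while_elems + std::max(cond_peak, body_peak);
  assert(total == 56);  // matches the EXPECT_EQ in the test
  return 0;
}
```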
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
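
A rough back-of-the-envelope sketch (plain Python arithmetic, not TensorFlow code) of where the expected value of 56 in the test above comes from, assuming the 4-byte F32 scalars and 8-byte tuple pointers implied by size_fn:

POINTER = 8
F32 = 4

entry_params = 2 * F32                        # param_iter + param_data = 8
tuple_pointers = 2 * POINTER                  # one pointer per tuple element = 16
while_elements = 2 * F32                      # while result scalars = 8
subcomputation_peak = 2 * POINTER + 2 * F32   # cond's tuple param = 24 (>= body's peak)

peak = entry_params + tuple_pointers + while_elements + subcomputation_peak
assert peak == 56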

@ -1160,28 +1160,25 @@ StatusOr<bool> HloRematerialization::Run(
                      TuplePointsToAnalysis::Run(
                          module, /*include_loop_fusion_instructions=*/true));

  // Adjust memory limit to account for the parameter and output of the entry
  // Adjust memory limit to account for the output of the entry
  // computation. This is necessary because the per-computation accounting in
  // MemoryUsageTracker do not include parameters and output as these are
  // typically allocated by the caller. With this adjustment the memory limit
  // accounts for the size of all HLO instructions (parameters, output
  // instructions, etc).
  auto total_size = [this](const HloInstruction* instruction) {
    int64 total_size = 0;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) {
      total_size += size_function_(logical_buffer->shape());
    }
    return total_size;
  };
  const HloComputation* entry_computation = module->entry_computation();
  memory_limit_bytes -= total_size(entry_computation->root_instruction());
  for (const HloInstruction* param :
       entry_computation->parameter_instructions()) {
    memory_limit_bytes -= total_size(param);
  }
  VLOG(1) << "Adjusted memory limit accounting for parameters and output: "
          << HumanReadableNumBytes(memory_limit_bytes);
  // MemoryUsageTracker do not include output as these are typically allocated
  // by the caller.
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module->entry_computation()->root_instruction()->shape(),
      [&module_output_size, this](const Shape& subshape,
                                  const ShapeIndex& /*index*/) {
        module_output_size += size_function_(subshape);
        return Status::OK();
      })
      .IgnoreError();

  const int64 adjusted_memory_limit_bytes =
      memory_limit_bytes - module_output_size;
  VLOG(1) << "Adjusted memory limit accounting for output ("
          << HumanReadableNumBytes(module_output_size)
          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);

  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
  // Create initial sequence of HLO instructions.

@ -1204,8 +1201,13 @@ StatusOr<bool> HloRematerialization::Run(
        return Status::OK();
      }));

  // The peak memory usage of the module equals the peak memory use of the entry
  // computation plus the output size of the computation. This is because the
  // peak memory for a computation does not include the output as this is
  // typically accounted for in the caller.
  const int64 before_peak_memory =
      computation_peak_memory_.at(module->entry_computation());
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module (before): "
          << HumanReadableNumBytes(before_peak_memory);
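
A sketch of the two-sided accounting introduced above, with made-up sizes: the rematerializer subtracts the entry output from the user's limit (the caller allocates the output), then adds it back when reporting the module's peak. Plain Python, illustrative values only:

memory_limit_bytes = 1024
module_output_size = 128   # hypothetical total of the entry root's subshapes
entry_peak = 700           # hypothetical computation_peak_memory_ entry

adjusted_memory_limit_bytes = memory_limit_bytes - module_output_size
before_peak_memory = entry_peak + module_output_size

assert adjusted_memory_limit_bytes == 896
assert before_peak_memory == 828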

@ -1216,9 +1218,9 @@ StatusOr<bool> HloRematerialization::Run(

  // Subcomputations called by the entry computation will also be
  // rematerialized.
  TF_ASSIGN_OR_RETURN(bool changed,
                      RematerializeComputation(module->entry_computation(),
                                               sequence, memory_limit_bytes));
  TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
                                        module->entry_computation(), sequence,
                                        adjusted_memory_limit_bytes));

  // Rematerialization can introduce dead code. This occurs if all uses of an
  // instruction are replaced with rematerializations of the instruction.

@ -1257,7 +1259,8 @@ StatusOr<bool> HloRematerialization::Run(
          << " instructions in module " << module->name() << "; "
          << net_instructions_added_ << " net instructions added";
  const int64 current_peak_memory =
      computation_peak_memory_.at(module->entry_computation());
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module now "
          << HumanReadableNumBytes(current_peak_memory) << " ("
          << current_peak_memory << " bytes), was "

@ -1928,6 +1928,12 @@ HloInstruction* ComputationLowerer::Visit(

  const OperationRequest& request =
      session_computation_.requests().at(handle.handle());
  auto add_instruction = [&](std::unique_ptr<HloInstruction> instruction) {
    HloInstruction* hlo_instruction =
        hlo_builder_.AddInstruction(std::move(instruction));
    hlo_instruction->set_metadata(request.request().metadata());
    return hlo_instruction;
  };
  HloInstruction* hlo_instruction;
  switch (request.request().op_case()) {
    case OpRequest::kRngRequest: {

@ -1936,7 +1942,7 @@ HloInstruction* ComputationLowerer::Visit(
      for (const ComputationDataHandle& param : rng_request.parameter()) {
        parameters.push_back(Visit(param, visited));
      }
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng(
      hlo_instruction = add_instruction(HloInstruction::CreateRng(
          request.output_shape(), rng_request.distribution(), parameters));
      break;
    }

@ -1944,8 +1950,7 @@ HloInstruction* ComputationLowerer::Visit(
    case OpRequest::kConstantRequest: {
      const ConstantRequest& constant_request =
          request.request().constant_request();
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateConstant(
      hlo_instruction = add_instruction(HloInstruction::CreateConstant(
          LiteralUtil::CloneToUnique(constant_request.literal())));
      break;
    }

@ -1955,17 +1960,15 @@ HloInstruction* ComputationLowerer::Visit(
          request.request().get_tuple_element_request();
      HloInstruction* operand =
          Visit(get_tuple_element_request.operand(), visited);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement(
              request.output_shape(), operand,
              get_tuple_element_request.index()));
      hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement(
          request.output_shape(), operand, get_tuple_element_request.index()));
      break;
    }

    case OpRequest::kSliceRequest: {
      const SliceRequest& slice_request = request.request().slice_request();
      HloInstruction* operand = Visit(slice_request.operand(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice(
      hlo_instruction = add_instruction(HloInstruction::CreateSlice(
          request.output_shape(), operand,
          AsInt64Slice(slice_request.start_indices()),
          AsInt64Slice(slice_request.limit_indices())));

@ -1979,8 +1982,7 @@ HloInstruction* ComputationLowerer::Visit(
      HloInstruction* start_indices =
          Visit(dynamic_slice_request.start_indices(), visited);

      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice(
      hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice(
          request.output_shape(), operand, start_indices,
          AsInt64Slice(dynamic_slice_request.slice_sizes())));
      break;

@ -1996,7 +1998,7 @@ HloInstruction* ComputationLowerer::Visit(
      HloInstruction* start_indices =
          Visit(dynamic_update_slice_request.start_indices(), visited);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          add_instruction(HloInstruction::CreateDynamicUpdateSlice(
              request.output_shape(), operand, update, start_indices));
      break;
    }

@ -2010,9 +2012,8 @@ HloInstruction* ComputationLowerer::Visit(
        HloInstruction* operand = Visit(handle, visited);
        operands.push_back(operand);
      }
      hlo_instruction = hlo_builder_.AddInstruction(
          HloInstruction::CreateConcatenate(request.output_shape(), operands,
                                            concatenate_request.dimension()));
      hlo_instruction = add_instruction(HloInstruction::CreateConcatenate(
          request.output_shape(), operands, concatenate_request.dimension()));
      break;
    }

@ -2021,8 +2022,7 @@ HloInstruction* ComputationLowerer::Visit(
          request.request().convolve_request();
      HloInstruction* lhs = Visit(convolve_request.lhs(), visited);
      HloInstruction* rhs = Visit(convolve_request.rhs(), visited);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateConvolve(
      hlo_instruction = add_instruction(HloInstruction::CreateConvolve(
          request.output_shape(), lhs, rhs, convolve_request.window(),
          convolve_request.dimension_numbers()));
      break;

@ -2033,16 +2033,14 @@ HloInstruction* ComputationLowerer::Visit(
          request.request().cross_replica_sum_request();
      HloInstruction* operand =
          Visit(cross_replica_sum_request.operand(), visited);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum(
      hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum(
          request.output_shape(), operand));
      break;
    }

    case OpRequest::kInfeedRequest: {
      const InfeedRequest& infeed_request = request.request().infeed_request();
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateInfeed(
      hlo_instruction = add_instruction(HloInstruction::CreateInfeed(
          request.output_shape(), infeed_request.config()));
      break;
    }

@ -2051,9 +2049,8 @@ HloInstruction* ComputationLowerer::Visit(
      const OutfeedRequest& outfeed_request =
          request.request().outfeed_request();
      HloInstruction* operand = Visit(outfeed_request.operand(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(
          HloInstruction::CreateOutfeed(outfeed_request.shape(), operand,
                                        outfeed_request.outfeed_config()));
      hlo_instruction = add_instruction(HloInstruction::CreateOutfeed(
          outfeed_request.shape(), operand, outfeed_request.outfeed_config()));
      break;
    }

@ -2069,7 +2066,7 @@ HloInstruction* ComputationLowerer::Visit(
          request.embedded_computation_versions(0);
      HloComputation* map_computation =
          ResolveComputation(map_request.to_apply(), map_version);
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap(
      hlo_instruction = add_instruction(HloInstruction::CreateMap(
          request.output_shape(), operands, map_computation));
      break;
    }

@ -2083,8 +2080,7 @@ HloInstruction* ComputationLowerer::Visit(
          request.embedded_computation_versions(0);
      HloComputation* reduce_computation =
          ResolveComputation(reduce_request.to_apply(), reduce_version);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateReduce(
      hlo_instruction = add_instruction(HloInstruction::CreateReduce(
          request.output_shape(), operand, init_value,
          AsInt64Slice(reduce_request.dimensions()), reduce_computation));
      break;

@ -2101,8 +2097,7 @@ HloInstruction* ComputationLowerer::Visit(
          request.embedded_computation_versions(0);
      HloComputation* reduce_window_computation = ResolveComputation(
          reduce_window_request.to_apply(), reduce_window_version);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow(
      hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow(
          request.output_shape(), operand, init_value,
          reduce_window_request.window(), reduce_window_computation));
      break;

@ -2126,8 +2121,7 @@ HloInstruction* ComputationLowerer::Visit(
          select_and_scatter_request.select(), select_version);
      HloComputation* scatter_computation = ResolveComputation(
          select_and_scatter_request.scatter(), scatter_version);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter(
      hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter(
          request.output_shape(), operand, select_computation,
          select_and_scatter_request.window(), source, init_value,
          scatter_computation));

@ -2151,8 +2145,7 @@ HloInstruction* ComputationLowerer::Visit(
            ShapeUtil::Rank(request.output_shape()) -
            ShapeUtil::Rank(operand->shape()));
      }
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
      hlo_instruction = add_instruction(HloInstruction::CreateBroadcast(
          request.output_shape(), operand, broadcast_dimensions));
      break;
    }

@ -2165,14 +2158,13 @@ HloInstruction* ComputationLowerer::Visit(
      if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) {
        transposed = operand;
      } else {
        transposed =
            hlo_builder_.AddInstruction(HloInstruction::CreateTranspose(
                ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice(
                                                 reshape_request.dimensions())),
        transposed = add_instruction(HloInstruction::CreateTranspose(
            ShapeUtil::PermuteDimensions(
                InversePermutation(AsInt64Slice(reshape_request.dimensions())),
                operand->shape()),
            operand, AsInt64Slice(reshape_request.dimensions())));
      }
      hlo_instruction = hlo_builder_.AddInstruction(
      hlo_instruction = add_instruction(
          HloInstruction::CreateReshape(request.output_shape(), transposed));
      break;
    }

@ -2181,10 +2173,9 @@ HloInstruction* ComputationLowerer::Visit(
      const TransposeRequest& transpose_request =
          request.request().transpose_request();
      HloInstruction* operand = Visit(transpose_request.operand(), visited);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateTranspose(
              ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice(
                                               transpose_request.dimensions())),
      hlo_instruction = add_instruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(
              InversePermutation(AsInt64Slice(transpose_request.dimensions())),
              operand->shape()),
          operand, AsInt64Slice(transpose_request.dimensions())));
      break;

@ -2194,8 +2185,7 @@ HloInstruction* ComputationLowerer::Visit(
      const ReverseRequest& reverse_request =
          request.request().reverse_request();
      HloInstruction* operand = Visit(reverse_request.operand(), visited);
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateReverse(
      hlo_instruction = add_instruction(HloInstruction::CreateReverse(
          request.output_shape(), operand,
          AsInt64Slice(reverse_request.dimensions())));
      break;

@ -2206,7 +2196,7 @@ HloInstruction* ComputationLowerer::Visit(
      HloInstruction* operand = Visit(pad_request.operand(), visited);
      HloInstruction* padding_value =
          Visit(pad_request.padding_value(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad(
      hlo_instruction = add_instruction(HloInstruction::CreatePad(
          request.output_shape(), operand, padding_value,
          pad_request.padding_config()));
      break;

@ -2214,7 +2204,7 @@ HloInstruction* ComputationLowerer::Visit(

    case OpRequest::kRecvRequest: {
      const RecvRequest& recv_request = request.request().recv_request();
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRecv(
      hlo_instruction = add_instruction(HloInstruction::CreateRecv(
          request.output_shape(), recv_request.channel_handle().handle()));
      break;
    }

@ -2222,8 +2212,7 @@ HloInstruction* ComputationLowerer::Visit(
    case OpRequest::kParameterRequest: {
      const ParameterRequest& parameter_request =
          request.request().parameter_request();
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateParameter(
      hlo_instruction = add_instruction(HloInstruction::CreateParameter(
          parameter_request.parameter(), request.output_shape(),
          parameter_request.name()));
      break;

@ -2233,7 +2222,7 @@ HloInstruction* ComputationLowerer::Visit(
      const ConvertRequest& convert_request =
          request.request().convert_request();
      HloInstruction* operand = Visit(convert_request.operand(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(
      hlo_instruction = add_instruction(
          HloInstruction::CreateConvert(request.output_shape(), operand));
      break;
    }

@ -2250,7 +2239,7 @@ HloInstruction* ComputationLowerer::Visit(
      HloComputation* body =
          ResolveComputation(while_request.body(), body_version);
      HloInstruction* init = Visit(while_request.init(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile(
      hlo_instruction = add_instruction(HloInstruction::CreateWhile(
          request.output_shape(), condition, body, init));
      break;
    }

@ -2262,8 +2251,7 @@ HloInstruction* ComputationLowerer::Visit(
      HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited);
      HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited);
      auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateTernary(
      hlo_instruction = add_instruction(HloInstruction::CreateTernary(
          request.output_shape(), hlo_opcode, lhs, rhs, ehs));
      break;
    }

@ -2279,8 +2267,7 @@ HloInstruction* ComputationLowerer::Visit(
      }
      auto hlo_opcode =
          VariadicOperationToHloOpcode(variadic_op_request.varop());
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateVariadic(
      hlo_instruction = add_instruction(HloInstruction::CreateVariadic(
          request.output_shape(), hlo_opcode, operands));
      break;
    }

@ -2296,7 +2283,7 @@ HloInstruction* ComputationLowerer::Visit(
          request.embedded_computation_versions(0);
      HloComputation* call_computation =
          ResolveComputation(call_request.to_apply(), call_version);
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall(
      hlo_instruction = add_instruction(HloInstruction::CreateCall(
          request.output_shape(), operands, call_computation));
      break;
    }

@ -2308,8 +2295,7 @@ HloInstruction* ComputationLowerer::Visit(
      for (const ComputationDataHandle& operand : cc_request.operands()) {
        operands.push_back(Visit(operand, visited));
      }
      hlo_instruction =
          hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall(
      hlo_instruction = add_instruction(HloInstruction::CreateCustomCall(
          cc_request.shape(), operands, cc_request.call_target_name()));
      break;
    }

@ -2319,7 +2305,7 @@ HloInstruction* ComputationLowerer::Visit(
          request.request().unary_op_request();
      HloInstruction* operand = Visit(unary_op_request.operand(), visited);
      auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop());
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary(
      hlo_instruction = add_instruction(HloInstruction::CreateUnary(
          request.output_shape(), hlo_opcode, operand));
      break;
    }

@ -2347,15 +2333,14 @@ HloInstruction* ComputationLowerer::Visit(
      // identical to the HLO broadcast semantics so the broadcast_dimensions
      // field can just be passed to the instruction builder.
      HloInstruction* broadcasted_operand =
          hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
          add_instruction(HloInstruction::CreateBroadcast(
              broadcast_shape, operand_to_broadcast,
              AsInt64Slice(binary_op_request.broadcast_dimensions())));

      lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
      rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
    }
    hlo_instruction =
        hlo_builder_.AddInstruction(HloInstruction::CreateBinary(
    hlo_instruction = add_instruction(HloInstruction::CreateBinary(
        request.output_shape(), hlo_opcode, lhs, rhs));
    break;
  }

@ -2363,7 +2348,7 @@ HloInstruction* ComputationLowerer::Visit(
    case OpRequest::kTraceRequest: {
      const TraceRequest& trace_request = request.request().trace_request();
      HloInstruction* operand = Visit(trace_request.operand(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(
      hlo_instruction = add_instruction(
          HloInstruction::CreateTrace(trace_request.tag(), operand));
      operand->set_tracing(hlo_instruction);
      break;

@ -2372,7 +2357,7 @@ HloInstruction* ComputationLowerer::Visit(
    case OpRequest::kSendRequest: {
      const SendRequest& send_request = request.request().send_request();
      HloInstruction* operand = Visit(send_request.operand(), visited);
      hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSend(
      hlo_instruction = add_instruction(HloInstruction::CreateSend(
          operand, send_request.channel_handle().handle()));
      break;
    }

@ -2383,7 +2368,6 @@ HloInstruction* ComputationLowerer::Visit(
    default:
      LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
  }
  hlo_instruction->set_metadata(request.request().metadata());
  (*visited)[handle.handle()] = hlo_instruction;
  return hlo_instruction;
}

@ -59,6 +59,9 @@ TEST_F(UserComputationTest, SimpleComputation) {
  param_request.set_name("param0");
  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle,
                         computation.AddParameterInstruction(param_request));
  OpMetadata metadata;
  metadata.set_op_name("meta");
  TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata));

  OutfeedRequest outfeed_request;
  *outfeed_request.mutable_operand() = constant_handle;

@ -135,6 +138,8 @@ TEST_F(UserComputationTest, SimpleComputation) {
    // The root of the instruction should be the parameter instruction (not the
    // outfeed).
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
    EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(),
              "meta");
  }
}

@ -18,6 +18,7 @@ limitations under the License.

#include <initializer_list>
#include <memory>
#include <random>
#include <string>

#include "tensorflow/compiler/xla/array2d.h"

@ -171,6 +172,36 @@ class LiteralTestUtil {
      tensorflow::gtl::ArraySlice<int64> minor_to_major,
      const Literal& literal);

  // Creates a literal with the supplied shape, and uses the provided value
  // generator to populate the literal's values.
  // Returns the new literal object, or an error Status if failed.
  template <
      PrimitiveType type,
      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
      const Shape& shape,
      const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);

  // Creates a literal with the supplied shape, and initializes the literal
  // values using a normal distribution with given mean and stddev standard
  // deviation, and using the engine as entropy generator.
  // Returns the new literal object, or an error Status if failed.
  template <
      PrimitiveType type, typename E,
      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
      const Shape& shape, E* engine, T mean, T stddev);

  // Creates a literal with the supplied shape, and initializes the literal
  // values using a normal distribution with given mean and stddev standard
  // deviation.
  // Returns the new literal object, or an error Status if failed.
  template <
      PrimitiveType type,
      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
      const Shape& shape, T mean, T stddev);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
};

@ -270,6 +301,40 @@ template <typename NativeT>
  ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error);
}

template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(
    const Shape& shape,
    const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
  TF_RET_CHECK(shape.element_type() == type);
  std::unique_ptr<Literal> literal = LiteralUtil::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(LiteralUtil::Populate<NativeT>(
      literal.get(), [&](tensorflow::gtl::ArraySlice<int64> indexes) {
        return generator(indexes);
      }));
  return std::move(literal);
}

template <PrimitiveType type, typename E, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
                                     T stddev) {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
  std::normal_distribution<NativeT> generator(mean, stddev);
  return CreateRandomLiteral<type, NativeT>(
      shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
        return generator(*engine);
      });
}

template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
  std::minstd_rand0 engine;
  return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
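
A NumPy sketch of what the CreateRandomLiteral overloads above do: walk every cell of a shaped array and fill it from a per-index generator, here a normal distribution. The names are illustrative and not part of the XLA API; this assumes a dense row-major layout.

import numpy as np

def create_random_literal(shape, mean, stddev, seed=0):
  rng = np.random.RandomState(seed)           # plays the role of the engine
  generator = lambda indexes: rng.normal(mean, stddev)
  out = np.empty(shape, dtype=np.float32)
  for indexes in np.ndindex(*shape):          # visit every cell, like Populate
    out[indexes] = generator(indexes)
  return out

literal = create_random_literal((2, 3), mean=0.0, stddev=1.0)
print(literal.shape)  # (2, 3)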

@ -741,6 +741,7 @@ class Dense(tf_core_layers.Dense, Layer):
      self.constraints[self.kernel] = self.kernel_constraint
    if self.use_bias and self.bias_constraint:
      self.constraints[self.bias] = self.bias_constraint
    self.built = True

  def get_config(self):
    config = {

@ -111,6 +111,7 @@ class _Merge(Layer):
      self._reshape_required = False
    else:
      self._reshape_required = True
    self.built = True

  def call(self, inputs):
    if self._reshape_required:

@ -302,6 +303,7 @@ class Concatenate(_Merge):
                       'inputs with matching shapes '
                       'except for the concat axis. '
                       'Got inputs shapes: %s' % (input_shape))
    self.built = True

  def call(self, inputs):
    if not isinstance(inputs, list):

@ -414,6 +416,7 @@ class Dot(_Merge):
      raise ValueError('Dimension incompatibility '
                       '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
                       'Layer shapes: %s, %s' % (shape1, shape2))
    self.built = True

  def call(self, inputs):
    x1 = inputs[0]

@ -166,6 +166,7 @@ class TimeDistributed(Wrapper):
      self.layer.build(child_input_shape)
      self.layer.built = True
    super(TimeDistributed, self).build()
    self.built = True

  def _compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()

@ -844,7 +844,7 @@ def convolution(inputs,
  variable would be created and added the activations. Finally, if
  `activation_fn` is not `None`, it is applied to the activations as well.

  Performs a'trous convolution with input stride/dilation rate equal to `rate`
  Performs atrous convolution with input stride/dilation rate equal to `rate`
  if a value > 1 for any dimension of `rate` is specified. In this case
  `stride` values != 1 are not supported.

@ -870,7 +870,7 @@ def convolution(inputs,
      "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    rate: A sequence of N positive integers specifying the dilation rate to use
      for a'trous convolution. Can be a single integer to specify the same
      for atrous convolution. Can be a single integer to specify the same
      value for all spatial dimensions. Specifying any `rate` value != 1 is
      incompatible with specifying any `stride` value != 1.
    activation_fn: Activation function. The default value is a ReLU function.

@ -1865,7 +1865,7 @@ def separable_convolution2d(
      depthwise convolution stride. Can be an int if both strides are the same.
    padding: One of 'VALID' or 'SAME'.
    rate: A list of length 2: [rate_height, rate_width], specifying the dilation
      rates for a'trous convolution. Can be an int if both rates are the same.
      rates for atrous convolution. Can be an int if both rates are the same.
      If any value is larger than one, then both stride values need to be one.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
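
An illustration of what the dilation `rate` in the atrous-convolution docstrings above buys: a kernel of size k with rate r covers k + (k - 1) * (r - 1) input positions. Plain arithmetic, not TensorFlow code:

def effective_kernel_size(k, rate):
  # Each of the k taps is separated by (rate - 1) skipped positions.
  return k + (k - 1) * (rate - 1)

assert effective_kernel_size(3, 1) == 3   # rate=1 is ordinary convolution
assert effective_kernel_size(3, 2) == 5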

@ -966,7 +966,8 @@ class BaseEstimator(
          saver.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              defer_build=True))
              defer_build=True,
              save_relative_paths=True))

    chief_hooks = []
    if (self._config.save_checkpoints_secs or

@ -28,6 +28,8 @@ import numpy as np
import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from google.protobuf import text_format

from tensorflow.contrib import learn
from tensorflow.contrib import lookup
from tensorflow.contrib.framework.python.ops import variables

@ -50,6 +52,7 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops

@ -61,6 +64,7 @@ from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib

@ -674,6 +678,38 @@ class EstimatorTest(test.TestCase):
        metrics={'MSE': metric_ops.streaming_mean_squared_error})
    self.assertLess(scores3['MSE'], scores['MSE'])

  def test_checkpoint_contains_relative_paths(self):
    tmpdir = tempfile.mkdtemp()
    est = estimator.Estimator(
        model_dir=tmpdir,
        model_fn=linear_model_fn_with_model_fn_ops)
    est.fit(input_fn=boston_input_fn, steps=5)

    checkpoint_file_content = file_io.read_file_to_string(
        os.path.join(tmpdir, 'checkpoint'))
    ckpt = checkpoint_state_pb2.CheckpointState()
    text_format.Merge(checkpoint_file_content, ckpt)
    self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
    self.assertAllEqual(
        ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)

  def test_train_save_copy_reload(self):
    tmpdir = tempfile.mkdtemp()
    model_dir1 = os.path.join(tmpdir, 'model_dir1')
    est1 = estimator.Estimator(
        model_dir=model_dir1,
        model_fn=linear_model_fn_with_model_fn_ops)
    est1.fit(input_fn=boston_input_fn, steps=5)

    model_dir2 = os.path.join(tmpdir, 'model_dir2')
    os.renames(model_dir1, model_dir2)
    est2 = estimator.Estimator(
        model_dir=model_dir2,
        model_fn=linear_model_fn_with_model_fn_ops)
    self.assertEqual(5, est2.get_variable_value('global_step'))
    est2.fit(input_fn=boston_input_fn, steps=5)
    self.assertEqual(10, est2.get_variable_value('global_step'))

  def testEstimatorParams(self):
    boston = base.load_boston()
    est = estimator.SKCompat(

@ -379,7 +379,12 @@ def multi_label_head(n_classes,
                     loss_fn=None):
  """Creates a Head for multi label classification.

  The Head uses sigmoid cross entropy loss.
  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label from a discrete set.

  This head by default uses sigmoid cross entropy loss, which expects as input
  a multi-hot tensor of shape `(batch_size, num_classes)`.

  Args:
    n_classes: Integer, number of classes, must be >= 2
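
A small NumPy illustration of the multi-hot targets the docstring above describes: each example can switch on any subset of the classes, and the loss is an independent sigmoid cross entropy per class. This is a sketch of the standard numerically stable formulation, not the head's actual code.

import numpy as np

logits = np.array([[2.0, -1.0, 0.5]])
labels = np.array([[1.0, 0.0, 1.0]])  # example belongs to classes 0 and 2

# sigmoid cross entropy: max(x, 0) - x*z + log(1 + exp(-|x|))
x, z = logits, labels
loss = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))
print(loss.sum())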

@ -28,6 +28,7 @@ import six
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as core_run_config
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib


@ -260,7 +261,9 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
        the feature.
      evaluation_master: the master on which to perform evaluation.
      model_dir: directory where model parameters, graph etc are saved. If
        `None`, see `Estimator` about where the model will be saved.
        `None`, will use `model_dir` property in `TF_CONFIG` environment
        variable. If both are set, must have same value. If both are `None`, see
        `Estimator` about where the model will be saved.
      session_config: a ConfigProto used to set session parameters, or None.
        Note - using this argument, it is easy to provide settings which break
        otherwise perfectly good models. Use with care.

@ -291,7 +294,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
    # create Scaffold and Saver in their model_fn to set these.
    self._keep_checkpoint_max = keep_checkpoint_max
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._model_dir = model_dir
    self._model_dir = _get_model_dir(model_dir)

  def replace(self, **kwargs):
    """Returns a new instance of `RunConfig` replacing specified properties.

@ -434,3 +437,21 @@ def _get_master(cluster_spec, task_type, task_id):
  # For backwards compatibility, we return empty string if task_type was
  # not set (task_type did not previously exist).
  return ''


def _get_model_dir(model_dir):
  """Returns `model_dir` based user provided `model_dir` or `TF_CONFIG`."""

  model_dir_in_tf_config = json.loads(
      os.environ.get('TF_CONFIG') or '{}').get('model_dir', None)
  if model_dir_in_tf_config is not None:
    if model_dir is not None and model_dir_in_tf_config != model_dir:
      raise ValueError(
          '`model_dir` provided in RunConfig construct, if set, '
          'must have the same value as the model_dir in TF_CONFIG. '
          'model_dir: {}\nTF_CONFIG["model_dir"]: {}.\n'.format(
              model_dir, model_dir_in_tf_config))

    logging.info('Using model_dir in TF_CONFIG: %s', model_dir_in_tf_config)

  return model_dir or model_dir_in_tf_config
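
A standalone sketch of the precedence _get_model_dir implements above: an explicit constructor value and TF_CONFIG must agree, and either one may supply the directory when the other is unset. The function name and the injected environment dict are illustrative only.

import json

def resolve_model_dir(model_dir, environ):
  from_tf_config = json.loads(environ.get('TF_CONFIG') or '{}').get('model_dir')
  if from_tf_config is not None and model_dir is not None:
    if from_tf_config != model_dir:
      raise ValueError('model_dir mismatch between RunConfig and TF_CONFIG')
  return model_dir or from_tf_config

env = {'TF_CONFIG': json.dumps({'model_dir': '/tmp/model'})}
assert resolve_model_dir(None, env) == '/tmp/model'
assert resolve_model_dir('/tmp/model', env) == '/tmp/model'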

@ -223,6 +223,27 @@ class RunConfigTest(test.TestCase):
    config = run_config_lib.RunConfig(model_dir=TEST_DIR)
    self.assertEqual(TEST_DIR, config.model_dir)

  def test_model_dir_in_tf_config(self):
    tf_config = {"model_dir": TEST_DIR}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      run_config = run_config_lib.RunConfig()
    self.assertEqual(TEST_DIR, run_config.model_dir)

  def test_model_dir_both_in_tf_config_and_constructor(self):
    tf_config = {"model_dir": TEST_DIR}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      run_config = run_config_lib.RunConfig(model_dir=TEST_DIR)
    self.assertEqual(TEST_DIR, run_config.model_dir)

  def test_model_dir_fail_if_constructor_value_mismatch_tf_config(self):
    tf_config = {"model_dir": TEST_DIR}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      with self.assertRaisesRegexp(
          ValueError,
          "`model_dir` provided in RunConfig .* must have "
          "the same value .* in TF_CONFIG"):
        run_config_lib.RunConfig(model_dir=TEST_DIR + "/sub_dir")

  def test_replace(self):
    config = run_config_lib.RunConfig(
        tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)

@ -65,12 +65,15 @@ class SquareLinearOperatorCompositionTest(
        # feed_dict.
        matrices = sess.run(matrices)
        operator = linalg.LinearOperatorComposition(
            [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph])
            [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph],
            is_square=True)
        feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
      else:
        operator = linalg.LinearOperatorComposition(
            [linalg.LinearOperatorFullMatrix(m) for m in matrices])
        feed_dict = None
        # Should be auto-set.
        self.assertTrue(operator.is_square)

      # Convert back to Tensor. Needed if use_placeholder, since then we have
      # already evaluated each matrix to a numpy array.

@ -45,9 +45,10 @@ class SquareLinearOperatorFullMatrixTest(
      # values are random and we want the same value used for both mat and
      # feed_dict.
      matrix = matrix.eval()
      operator = linalg.LinearOperatorFullMatrix(matrix_ph)
      operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True)
      feed_dict = {matrix_ph: matrix}
    else:
      # is_square should be auto-detected here.
      operator = linalg.LinearOperatorFullMatrix(matrix)
      feed_dict = None

@ -68,6 +69,8 @@ class SquareLinearOperatorFullMatrixTest(
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)
    # Auto-detected.
    self.assertTrue(operator.is_square)


class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(

@ -104,6 +107,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
      # values are random and we want the same value used for both mat and
      # feed_dict.
      matrix = matrix.eval()
      # is_square is auto-set because of self_adjoint/pd.
      operator = linalg.LinearOperatorFullMatrix(
          matrix_ph, is_self_adjoint=True, is_positive_definite=True)
      feed_dict = {matrix_ph: matrix}

@ -129,7 +133,8 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(

    # Should be auto-set
    self.assertTrue(operator.is_non_singular)
    self.assertTrue(operator._is_spd)
    self.assertTrue(operator._can_use_cholesky)
    self.assertTrue(operator.is_square)


class NonSquareLinearOperatorFullMatrixTest(

@ -157,16 +162,14 @@ class NonSquareLinearOperatorFullMatrixTest(
    return operator, mat, feed_dict

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues.
    matrix = [[3., 0.], [1., 1.]]
    matrix = [[3., 2., 1.], [1., 1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual(operator.is_positive_definite, None)
    self.assertEqual(operator.is_non_singular, None)
    self.assertFalse(operator.is_self_adjoint)
    self.assertFalse(operator.is_square)

  def test_matrix_must_have_at_least_two_dims_or_raises(self):
    with self.assertRaisesRegexp(ValueError, "at least 2 dimensions"):

@ -54,6 +54,9 @@ class LinearOperatorShape(linalg.LinearOperator):
  def _shape_tensor(self):
    return constant_op.constant(self._stored_shape, dtype=dtypes.int32)

  def _apply(self):
    raise NotImplementedError("Not needed for this test.")


class LinearOperatorApplyOnly(linalg.LinearOperator):
  """LinearOperator that simply wraps a [batch] matrix and implements apply."""

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib

from tensorflow.contrib import framework as contrib_framework

@ -25,6 +26,7 @@ from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

__all__ = ["LinearOperator"]

@ -50,11 +52,9 @@ class LinearOperator(object):

  #### Performance contract

  Subclasses should implement a method only if it can be done with a reasonable
  performance increase over generic dense operations, either in time, parallel
  scalability, or memory usage. For example, if the determinant can only be
  computed using `tf.matrix_determinant(self.to_dense())`, then determinants
  should not be implemented.
  Subclasses should only implement the assert methods
  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
  time.

  Class docstrings should contain an explanation of computational complexity.
  Since this is a high-performance library, attention should be paid to detail,

@ -100,7 +100,7 @@ class LinearOperator(object):
  operator.shape()
  ==> [2, 4, 4]

  operator.log_determinant()
  operator.log_abs_determinant()
  ==> Shape [2] Tensor

  x = ... Shape [2, 4, 5] Tensor

@ -131,6 +131,7 @@ class LinearOperator(object):
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """
  __metaclass__ = abc.ABCMeta

  def __init__(self,
               dtype,

@ -167,17 +168,23 @@ class LinearOperator(object):
      ValueError: If hints are set incorrectly.
    """
    # Check and auto-set flags.
    if is_square is False:
      if is_non_singular or is_positive_definite:
        raise ValueError(
            "A non-singular or positive definite operator is always square.")
    self._is_square_set_by_user = is_square

    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not contrib_framework.is_tensor(t):
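
A standalone sketch of the hint-resolution chain in the __init__ above: positive-definite implies non-singular, and non-singular or self-adjoint imply square, with a ValueError on contradictory hints. Function name is illustrative only.

def resolve_square_hint(is_non_singular, is_self_adjoint, is_positive_definite,
                        is_square):
  if is_positive_definite:
    if is_non_singular is False:
      raise ValueError('A positive definite matrix is always non-singular.')
    is_non_singular = True
  if is_non_singular:
    if is_square is False:
      raise ValueError('A non-singular matrix is always square.')
    is_square = True
  if is_self_adjoint:
    if is_square is False:
      raise ValueError('A self-adjoint matrix is always square.')
    is_square = True
  return is_square

assert resolve_square_hint(None, None, True, None) is True
assert resolve_square_hint(None, None, None, None) is None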

@ -239,15 +246,16 @@ class LinearOperator(object):
    """Return `True/False` depending on if this operator is square."""
    # Static checks done after __init__. Why? Because domain/range dimension
    # sometimes requires lots of work done in the derived class after init.
    static_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_by_user is False and static_square_check:
    auto_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
      raise ValueError(
          "User set is_square hint to False, but the operator was square.")
    if self._is_square_set_by_user is None:
      return static_square_check
    if self._is_square_set_or_implied_by_hints is None:
      return auto_square_check

    return self._is_square_set_by_user
    return self._is_square_set_or_implied_by_hints

  @abc.abstractmethod
  def _shape(self):
    # Write this in derived class to enable all static shape methods.
    raise NotImplementedError("_shape is not implemented.")

@ -265,6 +273,7 @@ class LinearOperator(object):
    """
    return self._shape()

  @abc.abstractmethod
  def _shape_tensor(self):
    raise NotImplementedError("_shape_tensor is not implemented.")

@ -367,8 +376,7 @@ class LinearOperator(object):
        self._cached_tensor_rank_tensor = ops.convert_to_tensor(
            self.tensor_rank)
      else:
        self._cached_tensor_rank_tensor = array_ops.size(
            self.shape_tensor())
        self._cached_tensor_rank_tensor = array_ops.size(self.shape_tensor())
    return self._cached_tensor_rank_tensor

  @property

@ -486,9 +494,10 @@ class LinearOperator(object):
    """Check that arg.dtype == self.dtype."""
    if arg.dtype != self.dtype:
      raise TypeError(
          "Expected argument to have dtype %s. Found: %s in tensor %s"
          % (self.dtype, arg.dtype, arg))
          "Expected argument to have dtype %s. Found: %s in tensor %s" %
          (self.dtype, arg.dtype, arg))

  @abc.abstractmethod
  def _apply(self, x, adjoint=False, adjoint_arg=False):
    raise NotImplementedError("_apply is not implemented.")

@ -517,7 +526,9 @@ class LinearOperator(object):
      return self._apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _determinant(self):
    raise NotImplementedError("_det is not implemented.")
    if self._can_use_cholesky():
      return math_ops.exp(self.log_abs_determinant())
    return linalg_ops.matrix_determinant(self._matrix)

  def determinant(self, name="det"):
    """Determinant for every batch member.

@ -539,7 +550,11 @@ class LinearOperator(object):
      return self._determinant()

  def _log_abs_determinant(self):
    raise NotImplementedError("_log_abs_det is not implemented.")
    if self._can_use_cholesky():
      diag = array_ops.matrix_diag_part(self._get_cached_chol())
      return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
    abs_det = math_ops.abs(self.determinant())
    return math_ops.log(abs_det)

  def log_abs_determinant(self, name="log_abs_det"):
    """Log absolute value of determinant for every batch member.

@ -561,13 +576,20 @@ class LinearOperator(object):
      return self._log_abs_determinant()

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # Since this is an exact solve method for all rhs, this will only be
    # available for non-singular (batch) operators, in particular the operator
    # must be square.
    raise NotImplementedError("_solve is not implemented.")
    if self.is_square is False:
      raise NotImplementedError(
          "Solve is not yet implemented for non-square operators.")
    rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
    if self._can_use_cholesky():
      return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
    return linalg_ops.matrix_solve(
        self._get_cached_dense_matrix(), rhs, adjoint=adjoint)

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve `R` (batch) systems of equations exactly: `A X = rhs`.
    """Solve `R` (batch) systems of equations with best effort: `A X = rhs`.

    The solution may not be exact, and in this case it will be close in some
    sense (see class docstring for details).

    Examples:

@ -689,3 +711,20 @@ class LinearOperator(object):
      x = ops.convert_to_tensor(x, name="x")
      self._check_input_dtype(x)
      return self._add_to_tensor(x)

  def _can_use_cholesky(self):
    # TODO(langmore) Add complex types when tf.cholesky can use them.
    return (not self.dtype.is_complex and self.is_self_adjoint and
            self.is_positive_definite)

  def _get_cached_dense_matrix(self):
    if not hasattr(self, "_cached_dense_matrix"):
      self._cached_dense_matrix = self.to_dense()
    return self._cached_dense_matrix

  def _get_cached_chol(self):
    if not self._can_use_cholesky():
      return None
    if not hasattr(self, "_cached_chol"):
      self._cached_chol = linalg_ops.cholesky(self._get_cached_dense_matrix())
    return self._cached_chol
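
A NumPy check of the identity the Cholesky paths above rely on: for a real symmetric positive-definite A with A = L L^T, log|det A| = 2 * sum(log(diag(L))). Illustration only, not library code.

import numpy as np

a = np.array([[4.0, 1.0], [1.0, 3.0]])   # symmetric positive-definite
chol = np.linalg.cholesky(a)             # lower-triangular L
via_chol = 2.0 * np.sum(np.log(np.diag(chol)))
direct = np.log(abs(np.linalg.det(a)))
assert np.allclose(via_chol, direct)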
|
||||
|
@ -63,7 +63,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
operator.shape
|
||||
==> [2, 2]
|
||||
|
||||
operator.log_determinant()
|
||||
operator.log_abs_determinant()
|
||||
==> scalar Tensor
|
||||
|
||||
x = ... Shape [2, 4] Tensor
|
||||
@ -96,7 +96,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
#### Matrix property hints
|
||||
|
||||
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
|
||||
for `X = non_singular, self_adjoint, positive_definite`.
|
||||
for `X = non_singular, self_adjoint, positive_definite, square`.
|
||||
These have the following meaning
|
||||
* If `is_X == True`, callers should expect the operator to have the
|
||||
property `X`. This is a promise that should be fulfilled, but is *not* a
|
||||
@ -112,6 +112,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
is_non_singular=None,
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
is_square=None,
|
||||
name=None):
|
||||
r"""Initialize a `LinearOperatorComposition`.
|
||||
|
||||
@ -132,6 +133,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
is_square: Expect that this operator acts like square [batch] matrices.
|
||||
name: A name for this `LinearOperator`. Default is the individual
|
||||
operators names joined with `_o_`.
|
||||
|
||||
@ -177,6 +179,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
|
||||
@property
|
||||
|
@ -52,7 +52,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
operator.shape
|
||||
==> [2, 2]
|
||||
|
||||
operator.log_determinant()
|
||||
operator.log_abs_determinant()
|
||||
==> scalar Tensor
|
||||
|
||||
x = ... Shape [2, 4] Tensor
|
||||
@ -97,7 +97,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
#### Matrix property hints
|
||||
|
||||
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
|
||||
for `X = non_singular, self_adjoint, positive_definite`.
|
||||
for `X = non_singular, self_adjoint, positive_definite, square`.
|
||||
These have the following meaning
|
||||
* If `is_X == True`, callers should expect the operator to have the
|
||||
property `X`. This is a promise that should be fulfilled, but is *not* a
|
||||
@ -113,6 +113,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
is_non_singular=None,
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
is_square=None,
|
||||
name="LinearOperatorDiag"):
|
||||
r"""Initialize a `LinearOperatorDiag`.
|
||||
|
||||
@ -129,6 +130,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
is_square: Expect that this operator acts like square [batch] matrices.
|
||||
name: A name for this `LinearOperator`.
|
||||
|
||||
Raises:
|
||||
@ -147,12 +149,17 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
else:
|
||||
is_self_adjoint = True
|
||||
|
||||
if is_square is False:
|
||||
raise ValueError("Only square diagonal operators currently supported.")
|
||||
is_square = True
|
||||
|
||||
super(LinearOperatorDiag, self).__init__(
|
||||
dtype=self._diag.dtype,
|
||||
graph_parents=[self._diag],
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
|
||||
def _check_diag(self, diag):
|
||||
|
@ -19,11 +19,9 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.linalg.python.ops import linear_operator
from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

__all__ = ["LinearOperatorFullMatrix"]
@ -49,7 +47,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  operator.shape
  ==> [2, 2]

  operator.log_determinant()
  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
@ -93,7 +91,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite`.
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning
  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
@ -109,6 +107,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorFullMatrix"):
    r"""Initialize a `LinearOperatorFullMatrix`.

@ -124,6 +123,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
@ -134,19 +134,13 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
      self._matrix = ops.convert_to_tensor(matrix, name="matrix")
      self._check_matrix(self._matrix)

      # Special treatment for (real) Symmetric Positive Definite.
      self._is_spd = (
          (not self._matrix.dtype.is_complex)
          and is_self_adjoint and is_positive_definite)
      if self._is_spd:
        self._chol = linalg_ops.cholesky(self._matrix)

      super(LinearOperatorFullMatrix, self).__init__(
          dtype=self._matrix.dtype,
          graph_parents=[self._matrix],
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

  def _check_matrix(self, matrix):
@ -177,23 +171,5 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    if self._is_spd:
      return math_ops.exp(self.log_abs_determinant())
    return linalg_ops.matrix_determinant(self._matrix)

  def _log_abs_determinant(self):
    if self._is_spd:
      diag = array_ops.matrix_diag_part(self._chol)
      return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
    abs_det = math_ops.abs(self.determinant())
    return math_ops.log(abs_det)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
    if self._is_spd:
      return linalg_ops.cholesky_solve(self._chol, rhs)
    return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)

  def _to_dense(self):
    return self._matrix
@ -112,7 +112,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
  operator.shape
  ==> [2, 2]

  operator.log_determinant()
  operator.log_abs_determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
@ -180,7 +180,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite`.
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning
  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
@ -198,6 +198,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
               is_non_singular=True,
               is_self_adjoint=True,
               is_positive_definite=True,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorIdentity"):
    r"""Initialize a `LinearOperatorIdentity`.
@ -224,6 +225,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
@ -248,12 +250,15 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
      raise ValueError("An identity operator is always non-singular.")
    if not is_positive_definite:
      raise ValueError("An identity operator is always positive-definite.")
    if not is_square:
      raise ValueError("An identity operator is always square.")

    super(LinearOperatorIdentity, self).__init__(
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

    self._num_rows = linear_operator_util.shape_tensor(
@ -459,7 +464,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
  operator.shape
  ==> [2, 2]

  operator.log_determinant()
  operator.log_abs_determinant()
  ==> 2 * Log[3]

  x = ... Shape [2, 4] Tensor
@ -510,7 +515,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite`.
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning
  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
@ -527,6 +532,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorScaledIdentity"):
    r"""Initialize a `LinearOperatorScaledIdentity`.
@ -550,6 +556,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
@ -561,6 +568,9 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
    """
    self._assert_proper_shapes = assert_proper_shapes

    if not is_square:
      raise ValueError("A ScaledIdentity operator is always square.")

    with ops.name_scope(name, values=[multiplier, num_rows]):
      self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier")

@ -569,6 +579,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

      # Shape [B1,...Bb, 1, 1]
@ -53,7 +53,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
  operator.shape
  ==> [2, 2]

  operator.log_determinant()
  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
@ -90,7 +90,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite`.
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning
  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
@ -106,6 +106,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorTriL"):
    r"""Initialize a `LinearOperatorTriL`.

@ -126,12 +127,19 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `diag.dtype` is not an allowed type.
      ValueError:  If `is_square` is `False`.
    """

    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    is_square = True

    with ops.name_scope(name, values=[tril]):
      self._tril = ops.convert_to_tensor(tril, name="tril")
      self._check_tril(self._tril)
@ -144,6 +152,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

  def _check_tril(self, tril):
@ -2417,6 +2417,9 @@ tf_cc_test(
        ":test_main",
        ":testlib",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/core/kernels:cast_op",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:function_ops",
@ -1001,25 +1001,19 @@ string NewName(const Node* n, bool pretty) {
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
  // We visit nodes in forward topological sort order, which is a
  // possible execution order of the graph.
  std::vector<size_t> pending(g->num_node_ids());
  std::deque<const Node*> ready;
  for (const Node* n : g->nodes()) {
    pending[n->id()] = n->in_edges().size();
    if (pending[n->id()] == 0) ready.push_back(n);
  }
  gtl::InlinedVector<const Edge*, 4> inputs;
  gdef->Clear();
  gdef->mutable_versions()->CopyFrom(g->versions());
  while (!ready.empty()) {
    const Node* n = ready.front();
    ready.pop_front();
    for (const Edge* e : n->out_edges()) {
      const Node* next = e->dst();
      if (--pending[next->id()] == 0) {
        ready.push_back(next);

  std::vector<Node*> start_nodes;
  for (Node* n : g->nodes()) {
    if (n->out_edges().empty()) {
      start_nodes.push_back(n);
    }
  }
    if (!n->IsOp()) continue;

  ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
    if (!n->IsOp()) return;
    NodeDef* ndef = gdef->add_node();
    ndef->set_name(NewName(n, pretty));
    ndef->set_op(n->type_string());
@ -1054,7 +1048,7 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
        ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
      }
    }
  }
  });
}

string DebugString(const Graph* g) {
File diff suppressed because it is too large
@ -163,7 +163,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {

    InferenceContext* c = iter->second.get();
    DCHECK_GE(e->dst_input(), 0);
    if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) {
    if (node_context->MergeInput(e->dst_input(), c->output(e->src_output()))) {
      *refined = true;
    }

@ -174,7 +174,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
            e->dst_input(), c->output_handle_dtype(e->src_output()))) {
      *refined = true;
    }
    if (node_context->set_input_handle_shape(
    if (node_context->MergeInputHandleShape(
            e->dst_input(), c->output_handle_shape(e->src_output()))) {
      *refined = true;
    }
@ -400,16 +400,33 @@ void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
  }
}

// Wrapper around protocol buffer serialization that requests deterministic
// serialization, in particular for Map fields, which serialize in a random
// order by default. Returns true on success.
template <typename T>
static bool DeterministicSerialization(const T& t, string* result) {
  const int size = t.ByteSize();
  *result = string(size, '\0');
  ::tensorflow::protobuf::io::ArrayOutputStream array_stream(&(*result)[0],
                                                             size);
  ::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream);
  output_stream.SetSerializationDeterministic(true);
  t.SerializeWithCachedSizes(&output_stream);
  return !output_stream.HadError() && size == output_stream.ByteCount();
}

bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
  string a_str, b_str;
  a.SerializeToString(&a_str);
  b.SerializeToString(&b_str);
  DeterministicSerialization(a, &a_str);
  DeterministicSerialization(b, &b_str);
  // Note: it should be safe to compare proto serializations of the attr
  // values since at most one field should be set in each (indeed, it
  // must be the same field if they are to compare equal).
  // Exception: there are multiple equivalent representations of
  // TensorProtos.  So a return value of true implies a == b, but not the
  // converse.
  // TODO(phawkins): this is incorrect for NameAttrList attributes that may
  // contain nested AttrValue maps.
  return a_str == b_str;
}
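The helper above exists because proto map fields serialize in an unspecified order, so two equal AttrValues could otherwise produce different byte strings and spuriously compare unequal. A minimal standalone sketch of the same idea against plain protobuf, outside the TensorFlow wrappers (the helper names here are hypothetical):

#include <string>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

// Serialize msg with map fields forced into sorted key order, so equal
// messages yield byte-identical strings. Returns true on success.
template <typename T>
bool SerializeDeterministically(const T& msg, std::string* out) {
  const size_t size = msg.ByteSizeLong();
  out->assign(size, '\0');
  google::protobuf::io::ArrayOutputStream array_stream(&(*out)[0], size);
  google::protobuf::io::CodedOutputStream coded(&array_stream);
  coded.SetSerializationDeterministic(true);
  msg.SerializeWithCachedSizes(&coded);
  return !coded.HadError() &&
         size == static_cast<size_t>(coded.ByteCount());
}

Byte-wise comparison of two such strings is then a sound equality test, though (as the comment above notes) still an incomplete one for messages with multiple equivalent encodings, such as TensorProto.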
@ -191,16 +191,18 @@ class InferenceContext {
    return s;
  }

  // Set the shape of the input in position idx. This requires idx to be in the
  // [0, num_inputs) range. Returns true iff the stored input shape has been
  // updated with a different handle.
  bool set_input(int idx, ShapeHandle shape) {
    if (!inputs_[idx].SameHandle(shape)) {
      inputs_[idx] = shape;
      return true;
    } else {
  // Merge the stored shape of the input in position idx with the specified
  // shape. This requires idx to be in the [0, num_inputs) range. If the merge
  // is successful and the new shape differs from the old one, store the new
  // shape and return true. Return false otherwise.
  bool MergeInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
        inputs_[idx].SameHandle(new_shape)) {
      return false;
    }
    inputs_[idx] = new_shape;
    return true;
  }
  ShapeHandle input(int64 idx) const { return inputs_[idx]; }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
@ -442,16 +444,19 @@ class InferenceContext {
  // propagate that information. Output handle dtypes and shapes are ignored if
  // the output tensor is not of type DT_RESOURCE.

  // Set the shape corresponding to the resource in position idx. This requires
  // idx to be in the [0, num_inputs) range. Returns true iff the stored shape
  // has been updated with a different handle.
  bool set_input_handle_shape(int idx, ShapeHandle shape) {
    if (!input_handle_shape_[idx].SameHandle(shape)) {
  // Merge the stored shape corresponding to the input handle in position idx
  // with the specified shape. This requires idx to be in the [0, num_inputs)
  // range. If the merge is successful and the new shape differs from the old
  // one, store the new shape and return true. Return false otherwise.
  bool MergeInputHandleShape(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(input_handle_shape_[idx], shape, &new_shape).ok() ||
        input_handle_shape_[idx].SameHandle(new_shape)) {
      return false;
    }
    input_handle_shape_[idx] = shape;
    return true;
  }
      return false;
    }

  // Set the type corresponding to the resource in position idx. This requires
  // idx to be in the [0, num_inputs) range. Returns true iff the stored type
@ -468,15 +473,24 @@ class InferenceContext {
    return input_handle_dtype_[idx];
  }

  // Set the shape corresponding to the resource in position idx. This requires
  // idx to be in the [0, num_outputs) range.
  // Returns true iff the stored shape has been updated with a different handle.
  bool set_output_handle_shape(int idx, ShapeHandle shape) {
    if (!output_handle_shape_[idx].SameHandle(shape)) {
  // Merge the stored shape corresponding to the output handle in position idx
  // with the specified shape. This requires idx to be in the [0, num_outputs)
  // range. If the merge is successful and the new shape differs from the old
  // one, store the new shape and return true. Return false otherwise.
  bool MergeOutputHandleShape(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(output_handle_shape_[idx], shape, &new_shape).ok() ||
        output_handle_shape_[idx].SameHandle(new_shape)) {
      return false;
    }
    output_handle_shape_[idx] = shape;
    return true;
  }
      return false;
  // Overwrite the shape corresponding to the output handle in position idx
  // with the specified shape.
  void set_output_handle_shape(int idx, ShapeHandle shape) {
    output_handle_shape_[idx] = shape;
  }

  // Set the type corresponding to the resource in position idx. This requires
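The switch from set_input to MergeInput changes the refinement loop's contract: instead of overwriting the stored shape whenever the handle differs, the context relaxes it via Merge and reports true only when the merged shape is genuinely new, which is what lets the ShapeRefiner's fixed-point iteration terminate. A toy model of the merge rule on plain dimension vectors, where -1 stands for an unknown dimension (an illustrative sketch, not the TF API):

#include <cstdint>
#include <vector>

// Merge succeeds when the ranks match and every dimension pair is
// compatible; a known dimension refines an unknown (-1) one.
bool MergeShapes(const std::vector<int64_t>& a, const std::vector<int64_t>& b,
                 std::vector<int64_t>* out) {
  if (a.size() != b.size()) return false;
  out->resize(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == -1) {
      (*out)[i] = b[i];
    } else if (b[i] == -1 || b[i] == a[i]) {
      (*out)[i] = a[i];
    } else {
      return false;  // Two different known sizes: incompatible.
    }
  }
  return true;
}

For example, merging {-1, 3} with {2, 3} yields the refinement {2, 3} (worth propagating), while merging {2, 3} with {-1, 3} leaves the stored shape unchanged, mirroring the SameHandle(new_shape) early-out above.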
@ -23,8 +23,8 @@ limitations under the License.

namespace tensorflow {

void DFS(const Graph& g, std::function<void(Node*)> enter,
         std::function<void(Node*)> leave) {
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave) {
  // Stack of work to do.
  struct Work {
    Node* node;
@ -61,15 +61,23 @@ void DFS(const Graph& g, std::function<void(Node*)> enter,
    }
  }
}

void ReverseDFS(const Graph& g, std::function<void(Node*)> enter,
                std::function<void(Node*)> leave) {
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave) {
  ReverseDFSFrom(g, {g.sink_node()}, enter, leave);
}

void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
                    const std::function<void(Node*)>& enter,
                    const std::function<void(Node*)>& leave) {
  // Stack of work to do.
  struct Work {
    Node* node;
    bool leave;  // Are we entering or leaving n?
  };
  std::vector<Work> stack;
  stack.push_back(Work{g.sink_node(), false});
  std::vector<Work> stack(start.size());
  for (int i = 0; i < start.size(); ++i) {
    stack[i] = Work{start[i], false};
  }

  std::vector<bool> visited(g.num_node_ids(), false);
  while (!stack.empty()) {
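Generalizing ReverseDFS to ReverseDFSFrom simply seeds the explicit work stack with an arbitrary set of start nodes instead of the single sink. A self-contained sketch of the same enter/leave stack pattern over a toy graph stored as reverse adjacency lists (each node lists its input node ids; names here are illustrative only):

#include <functional>
#include <vector>

void ReverseDfsFrom(const std::vector<std::vector<int>>& inputs,
                    const std::vector<int>& start,
                    const std::function<void(int)>& leave) {
  struct Work {
    int node;
    bool leave;  // Are we entering or leaving the node?
  };
  std::vector<Work> stack;
  for (int n : start) stack.push_back({n, false});
  std::vector<bool> visited(inputs.size(), false);
  while (!stack.empty()) {
    const Work w = stack.back();
    stack.pop_back();
    if (w.leave) {  // All reachable inputs processed: post-order callback.
      leave(w.node);
      continue;
    }
    if (visited[w.node]) continue;
    visited[w.node] = true;
    stack.push_back({w.node, true});  // Revisit on the way out.
    for (int in : inputs[w.node]) {
      if (!visited[in]) stack.push_back({in, false});
    }
  }
}

Seeded with every node that has no out-edges, the leave callbacks fire with each node after its inputs (producers before consumers), which is what the new ToGraphDef relies on to emit nodes in a valid order.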
@ -21,20 +21,28 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

// Perform a depth-first-search on g starting at the source node.
// If enter is not empty, calls enter(n) before visiting any children of n.
// If leave is not empty, calls leave(n) after visiting all children of n.
extern void DFS(const Graph& g, std::function<void(Node*)> enter,
                std::function<void(Node*)> leave);
extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave);

// Perform a reverse depth-first-search on g starting at the sink node.
// If enter is not empty, calls enter(n) before visiting any parents of n.
// If leave is not empty, calls leave(n) after visiting all parents of n.
extern void ReverseDFS(const Graph& g, std::function<void(Node*)> enter,
                       std::function<void(Node*)> leave);
extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                       const std::function<void(Node*)>& leave);

// Perform a reverse depth-first-search on g starting at the 'start' nodes.
// If enter is not empty, calls enter(n) before visiting any parents of n.
// If leave is not empty, calls leave(n) after visiting all parents of n.
extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
                           const std::function<void(Node*)>& enter,
                           const std::function<void(Node*)>& leave);

// Stores in *order the post-order numbering of all nodes
// in graph found via a depth first search starting at the source node.
@ -90,6 +90,23 @@ cc_test(
    ],
)

cc_library(
    name = "robust_stats",
    srcs = ["robust_stats.cc"],
    hdrs = ["robust_stats.h"],
    visibility = ["//visibility:public"],
)

cc_test(
    name = "robust_stats_test",
    srcs = ["robust_stats_test.cc"],
    deps = [
        ":robust_stats",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "utils",
    srcs = ["utils.cc"],
@ -116,3 +133,37 @@ cc_library(
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "virtual_scheduler",
    srcs = ["virtual_scheduler.cc"],
    hdrs = ["virtual_scheduler.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/costs:cost_estimator",
    ],
)

cc_library(
    name = "measuring_cost_estimator",
    srcs = ["measuring_cost_estimator.cc"],
    hdrs = ["measuring_cost_estimator.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":robust_stats",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:grappler_item_builder",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/costs:cost_estimator",
        "//tensorflow/core/kernels:ops_util",
    ],
)
@ -84,8 +84,8 @@ Status GraphProperties::InferStatically() {
        }
      }
    }
    if (qctx->set_output_handle_dtype(0, queue_type) ||
        qctx->set_output_handle_shape(0, queue_shp)) {
    if (qctx->set_output_handle_dtype(0, queue_type) |
        qctx->MergeOutputHandleShape(0, queue_shp)) {
      new_shapes.push(qnode);
    }
  }
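Note the replacement deliberately uses bitwise | rather than logical || here: both set_output_handle_dtype and MergeOutputHandleShape have side effects that must run on every pass, and || would short-circuit past the shape merge whenever the dtype update already returned true. A tiny standalone illustration of the difference (function names are hypothetical):

#include <iostream>

bool UpdateDtype() { std::cout << "dtype updated\n"; return true; }
bool MergeShape() { std::cout << "shape merged\n"; return false; }

int main() {
  // Logical ||: short-circuits, prints only "dtype updated".
  if (UpdateDtype() || MergeShape()) {
  }
  // Bitwise |: evaluates both operands, prints both lines.
  if (UpdateDtype() | MergeShape()) {
  }
  return 0;
}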
@ -177,10 +177,14 @@ TEST_F(GraphPropertiesTest, Queues) {
  auto dequeue2 =
      ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});

  // Create a queue that feeds itself.
  auto q3 =
      ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
  auto dequeue3 =
      ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
  auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2});
  auto enqueue3 =
      ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output});

  auto q4 =
      ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
@ -227,6 +231,229 @@ TEST_F(GraphPropertiesTest, Queues) {
  EXPECT_EQ(7, prop4.shape().dim(1).size());
}

TEST_F(GraphPropertiesTest, Loops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i = tf.constant(0)
      c = lambda i: tf.less(i, 10)
      b = lambda i: tf.add(i, 1)
      r = tf.while_loop(c, b, [i])
      with open('/tmp/graph.txt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */
  const string gdef_ascii = R"EOF(
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "while/Enter"
  op: "Enter"
  input: "Const"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while/"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 10
    }
  }
}
node {
  name: "while/Merge"
  op: "Merge"
  input: "while/Enter"
  input: "while/NextIteration"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Less/y"
  op: "Const"
  input: "^while/Merge"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 10
      }
    }
  }
}
node {
  name: "while/Less"
  op: "Less"
  input: "while/Merge"
  input: "while/Less/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/LoopCond"
  op: "LoopCond"
  input: "while/Less"
}
node {
  name: "while/Switch"
  op: "Switch"
  input: "while/Merge"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge"
      }
    }
  }
}
node {
  name: "while/Identity"
  op: "Identity"
  input: "while/Switch:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Add/y"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "while/Add"
  op: "Add"
  input: "while/Identity"
  input: "while/Add/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration"
  op: "NextIteration"
  input: "while/Add"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Exit"
  op: "Exit"
  input: "while/Switch"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 11
}
)EOF";

  GrapplerItem item;
  CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically());

  const auto props = properties.GetOutputProperties("while/Exit");
  EXPECT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_INT32, prop.dtype());
  EXPECT_TRUE(prop.shape().unknown_rank());
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
tensorflow/core/grappler/costs/measuring_cost_estimator.cc (new file, 133 lines)
@ -0,0 +1,133 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"

#include <limits>

#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/robust_stats.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace grappler {

MeasuringCostEstimator::MeasuringCostEstimator(Cluster* cluster,
                                               int measurement_steps,
                                               int measurement_threads)
    : measurement_steps_(measurement_steps),
      measurement_threads_(measurement_threads) {
  CHECK_GE(measurement_steps, 1);
  if (measurement_threads > 0) {
    thread_pool_.reset(new thread::ThreadPool(
        Env::Default(), SanitizeThreadSuffix("measurements"),
        measurement_threads));
  }
  cluster_ = cluster;
}

Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) {
  feed_ = item.feed;
  fetch_ = item.fetch;
  return cluster_->Initialize(item);
}

Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
                                            CostGraphDef* cost_graph,
                                            Costs* costs) const {
  std::vector<double> times(measurement_steps_);
  BlockingCounter barrier(measurement_steps_);

  mutex status_mu;
  Status status;

  auto measurement_fn = [&](const int step) {
    const Costs::MicroSeconds start = Env::Default()->NowMicros();

    RunMetadata metadata;
    const Status local_status =
        cluster_->Run(optimized_graph, feed_, fetch_, &metadata);
    {
      mutex_lock lock(status_mu);
      status.Update(local_status);
    }
    if (step < 0) {
      // Discard the first iteration as it triggers the warmup, and therefore
      // takes much longer than a normal step.
      return;
    }
    if (!local_status.ok()) {
      // Discard the data if the run wasn't successful.
      barrier.DecrementCount();
      return;
    }

    const Costs::MicroSeconds finish = Env::Default()->NowMicros();
    const double time = (finish - start).count() * 1e3;
    times[step] = time;

    if (cost_graph && (step + 1 == measurement_steps_)) {
      metadata.mutable_cost_graph()->Swap(cost_graph);
    }

    barrier.DecrementCount();
  };

  // Initialize the computation and warm up TensorFlow.
  measurement_fn(-1);

  if (!status.ok()) {
    LOG(ERROR) << "Failed to run start measurements: "
               << status.error_message();
    costs->execution_time = Costs::Duration::max();
    return status;
  }

  // Run "measurement_steps_" steps and measure the time.
  if (measurement_threads_ > 0) {
    for (int i = 0; i < measurement_steps_; ++i) {
      thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); });
    }
    barrier.Wait();
  } else {
    for (int i = 0; i < measurement_steps_ && status.ok(); ++i) {
      measurement_fn(i);
    }
  }

  if (!status.ok()) {
    LOG(ERROR) << "Failed to measure graph performance: "
               << status.error_message();
    costs->execution_time = Costs::Duration::max();
    costs->max_execution_time = Costs::Duration::max();
    costs->min_execution_time = 0;
    return status;
  }

  // Compute the average time of the measure steps. Use Huber statistics
  // to filter out outliers.
  RobustStats stats(times);
  costs->execution_time = Costs::Duration(stats.mean());
  costs->max_execution_time = Costs::Duration(stats.hi());
  costs->min_execution_time = Costs::Duration(stats.lo());

  return Status::OK();
}
}  // end namespace grappler
}  // end namespace tensorflow
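PredictCosts above follows a warmup-then-measure pattern: step -1 runs once to trigger one-time initialization and is discarded, then measurement_steps_ timed runs are aggregated robustly. A standalone sketch of that skeleton using std::chrono, independent of the TensorFlow types (the helper name is illustrative):

#include <chrono>
#include <functional>
#include <vector>

// Times 'run' for 'steps' iterations in microseconds, after one discarded
// warmup call; the caller feeds the result to a robust aggregator.
std::vector<double> MeasureMicros(const std::function<void()>& run,
                                  int steps) {
  std::vector<double> times(steps);
  auto measure = [&](int step) {
    const auto start = std::chrono::steady_clock::now();
    run();
    const auto finish = std::chrono::steady_clock::now();
    if (step < 0) return;  // Warmup run: much slower, so discard it.
    times[step] =
        std::chrono::duration<double, std::micro>(finish - start).count();
  };
  measure(-1);  // Warm up.
  for (int i = 0; i < steps; ++i) measure(i);
  return times;
}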
tensorflow/core/grappler/costs/measuring_cost_estimator.h (new file, 76 lines)
@ -0,0 +1,76 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_
#define TENSORFLOW_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class CostGraphDef;
class GraphDef;
}  // namespace tensorflow

namespace tensorflow {
namespace grappler {

class Cluster;
struct GrapplerItem;

// Estimate the cost of running a Grappler item by actually running the
// corresponding TensorFlow graph on the specified cluster and measuring the
// runtimes.
class MeasuringCostEstimator : public CostEstimator {
 public:
  // Run the model for measurement_steps to measure its average cost.
  // When measurement_threads is greater than 0, use a threadpool of as many
  // threads to run the measurements; otherwise, run them serially. Does not
  // take ownership of cluster.
  explicit MeasuringCostEstimator(Cluster* cluster, int measurement_steps,
                                  int measurement_threads);
  ~MeasuringCostEstimator() override {}

  // Initializes the estimator for the specified grappler item; returns the
  // status of initializing the underlying cluster.
  Status Initialize(const GrapplerItem& item) override;

  // Runs the optimized version of the graph on the cluster, measures
  // the runtimes of each operation, and annotates the CostGraphDef
  // with the corresponding measurements.
  // Returns the average latency for the whole graph.
  Status PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph,
                      Costs* overall_cost) const override;

 private:
  Cluster* cluster_;  // Not owned.
  int measurement_steps_;
  int measurement_threads_;
  std::vector<std::pair<string, Tensor>> feed_;
  std::vector<string> fetch_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_
tensorflow/core/grappler/costs/robust_stats.cc (new file, 152 lines)
@ -0,0 +1,152 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/robust_stats.h"
#include <algorithm>
#include <cmath>

namespace tensorflow {
namespace grappler {

// Given a sorted vector of values, calculate the median.
// Returns 0 for an empty vector. Does not verify sortedness.
static double SortedMedian(const std::vector<double> &values) {
  const int n = values.size();
  if (n == 0) return 0.0;
  if (n & 1) {
    return values[n / 2];
  } else {
    return (values[n / 2] + values[n / 2 - 1]) / 2.0;
  }
}

// Given a vector of values (sorted or not), calculate the median.
static double Median(std::vector<double> &&values) {
  const size_t n = values.size();
  if (n == 0) return 0;
  const auto middle = values.begin() + (n / 2);
  // Put the middle value in its place.
  std::nth_element(values.begin(), middle, values.end());
  if (n & 1) {
    return *middle;
  }
  // Return the average of the two elements, the max_element lower than
  // *middle is found between begin and middle as a post-cond of
  // nth_element.
  const auto lower_middle = std::max_element(values.begin(), middle);
  // Preventing overflow. We know that '*lower_middle <= *middle'.
  // If both are on opposite sides of zero, the sum won't overflow, otherwise
  // the difference won't overflow.
  if (*lower_middle <= 0 && *middle >= 0) {
    return (*lower_middle + *middle) / 2;
  }
  return *lower_middle + (*middle - *lower_middle) / 2;
}

// Given a set of values, calculates the scaled Median Absolute Deviation (a
// robust approximation to the standard deviation). This is calculated as the
// median of the absolute deviations from the median, scaled by 1.4826. Its
// advantage over the standard deviation is that it is not (as) affected by
// outlier values. Returns a pair<median, mad>.
static std::pair<double, double> ScaledMedianAbsoluteDeviation(
    const std::vector<double> &sorted_values) {
  double median = SortedMedian(sorted_values);

  // Next, we calculate the absolute deviations from the median,
  // find the median of the resulting data, and scale by 1.4826.
  std::vector<double> deviations;
  deviations.reserve(sorted_values.size());
  for (double d : sorted_values) {
    deviations.push_back(std::abs(d - median));
  }
  double mad = Median(std::move(deviations)) * 1.4826;
  return std::pair<double, double>(median, mad);
}

RobustStats::RobustStats(const std::vector<double> &values)
    : RobustStats(std::vector<double>(values)) {}

RobustStats::RobustStats(std::vector<double> &&values) {
  std::sort(values.begin(), values.end());
  lo_ = values[0];
  hi_ = values.back();
  HuberMAD(values);
}

// Computes an updated mean using Huber's weighting function (values beyond
// the margin are weighted by margin / abs(value - mean)).
double UpdateHuberMean(const std::vector<double> &sorted_values, double mean,
                       double margin) {
  int num_within = 0;
  double sum = 0.0;

  for (double d : sorted_values) {
    if (d < mean - margin) {
      sum -= margin;
    } else if (d > mean + margin) {
      sum += margin;
    } else {
      sum += d;
      ++num_within;
    }
  }

  // It is possible, for a set with an interquartile distance of 0, i.e., with
  // more than half of the values at the median, to encounter the case where
  // the Huber mean drifts slightly off the median and there are no values
  // within the margin. In that case, just return the old mean, and the caller
  // will quit.
  if (num_within > 0) {
    return sum / num_within;
  } else {
    return mean;
  }
}

// Given a list of values, this approximates the stddev using the MAD and then
// uses it to compute a Huber robust mean (sandwich mean). A margin of
// c*stddev is defined around the current mean, and values are weighted by
// margin / abs(value - mean) if outside the margin, or 1 if inside. This
// computes the mean iteratively, because each time it changes the margin
// shifts a bit. It typically settles very quickly, but it's possible for it
// to be unstable. We limit it to 10 iterations.
//
void RobustStats::HuberMAD(const std::vector<double> &sorted_values) {
  const std::pair<double, double> median_mad =
      ScaledMedianAbsoluteDeviation(sorted_values);
  mean_ = median_mad.first;
  stddev_ = median_mad.second;

  // c = 1.345 is the commonly used cutoff with 95% efficiency at the normal.
  // We're using c = 1.5 to be a little more conservative, and because that's
  // the default in S-plus.
  // TODO(dehnert): Specialize Stats for integral types so we don't implement
  // methods that don't make sense.
  const double c = 1.5;
  const double margin = c * stddev_;

  // Iterate 10 times, or until the Huber mean stabilizes.
  // If the margin is zero, we don't want mean to drift from the median.
  if (margin > 0.0) {
    for (int k = 0; k < 10; ++k) {
      double old_mean = mean_;
      mean_ = UpdateHuberMean(sorted_values, mean_, margin);
      if (mean_ == old_mean) break;
    }
  }
}

}  // namespace grappler
}  // namespace tensorflow
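To see why the scaled MAD is preferred to the standard deviation here, consider one outlier among five samples. A standalone worked example (not using the class above):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

double MedianOf(std::vector<double> v) {
  std::sort(v.begin(), v.end());
  const size_t n = v.size();
  return (n & 1) ? v[n / 2] : (v[n / 2 - 1] + v[n / 2]) / 2.0;
}

int main() {
  // Four well-behaved timings and one warmup-like outlier.
  const std::vector<double> v = {9.0, 10.0, 10.0, 11.0, 100.0};
  const double median = MedianOf(v);  // 10.0
  std::vector<double> deviations;
  for (double d : v) deviations.push_back(std::abs(d - median));
  const double mad = MedianOf(deviations) * 1.4826;  // 1.4826
  // The ordinary standard deviation of the same data is about 36,
  // dominated by the single outlier; the scaled MAD barely notices it.
  std::printf("median=%g scaled_mad=%g\n", median, mad);
  return 0;
}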
tensorflow/core/grappler/costs/robust_stats.h (new file, 42 lines)
@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_GRAPPLER_COSTS_ROBUST_STATS_H_
#define TENSORFLOW_GRAPPLER_COSTS_ROBUST_STATS_H_

#include <vector>
namespace tensorflow {
namespace grappler {
class RobustStats {
 public:
  RobustStats(const std::vector<double>& values);
  RobustStats(std::vector<double>&& values);

  double lo() const { return lo_; }
  double hi() const { return hi_; }
  double mean() const { return mean_; }

 private:
  void HuberMAD(const std::vector<double>& values);

  double lo_;
  double hi_;
  double mean_;
  double stddev_;
};
}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPPLER_COSTS_ROBUST_STATS_H_
tensorflow/core/grappler/costs/robust_stats_test.cc (new file, 63 lines)
@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/robust_stats.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

class RobustStatsTest : public ::testing::Test {
 public:
  void SetUp() override {
    for (double d = 1.0; d <= 5.0; d += 1.0) {
      values1_.push_back(5.0 - d);
      values1_.push_back(5.0 + d);
      values2_.push_back(25.0 - 2 * d);
      values2_.push_back(25.0 + 2 * d);
      values3_.push_back(-3.0 - d);
      values3_.push_back(-3.0 + d);
    }
    values1_.push_back(5.0);  // Odd # elements, mean is 5.0
    values3_.push_back(197.0);
    values3_.push_back(-203.0);  // Even # elements, mean is -3.0
  }

  std::vector<double> values1_;
  std::vector<double> values2_;
  std::vector<double> values3_;
};

TEST_F(RobustStatsTest, Simple) {
  RobustStats s1(values1_);
  EXPECT_EQ(5.0, s1.mean());
  EXPECT_EQ(0.0, s1.lo());
  EXPECT_EQ(10.0, s1.hi());

  RobustStats s2(values2_);
  EXPECT_EQ(25.0, s2.mean());
  EXPECT_EQ(15.0, s2.lo());
  EXPECT_EQ(35.0, s2.hi());

  RobustStats s3(values3_);
  EXPECT_EQ(-3.0, s3.mean());
  EXPECT_EQ(-203.0, s3.lo());
  EXPECT_EQ(197.0, s3.hi());
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
tensorflow/core/grappler/costs/virtual_scheduler.cc (new file, 215 lines)
@ -0,0 +1,215 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace grappler {
namespace {

Costs CombineCosts(const Costs& left, const Costs& right) {
  CHECK_NE(left.max_memory, kMemoryUnknown);
  CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
  CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);

  Costs result = left;
  result.execution_time += right.execution_time;
  if (right.max_memory != kMemoryUnknown) {
    result.max_memory += right.max_memory;
  }
  if (right.max_per_op_buffers != kMemoryUnknown) {
    result.max_per_op_buffers =
        std::max(left.max_per_op_buffers, right.max_per_op_buffers);
  }
  if (right.max_per_op_streaming != kMemoryUnknown) {
    result.max_per_op_streaming =
        std::max(left.max_per_op_streaming, right.max_per_op_streaming);
  }
  VLOG(2) << "costs execution_time=" << result.execution_time.count()
          << " max_memory=" << result.max_memory
          << " max_per_op_buffers=" << result.max_per_op_buffers
          << " max_per_op_streaming=" << result.max_per_op_streaming;
  return result;
}
}  // namespace

VirtualScheduler::VirtualScheduler(const GraphDef& graph,
                                   const std::vector<string>& fetch_nodes)
    : graph_costs_(Costs::ZeroCosts()),
      // TODO(dyoon): Use a better way than FIFO.
      ready_nodes_(new FIFOManager()) {
  // First, get the nodes that would run to output fetch_nodes.
  std::vector<const NodeDef*> nodes =
      ComputeTransitiveFanin(graph, fetch_nodes);

  // TODO(dyoon): this is a bit inefficient as name_to_node is already built in
  // ComputeTransitiveFanin().
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (const auto& node : graph.node()) {
    name_to_node[node.name()] = &node;
  }

  // Build node_map.
  for (const auto* node : nodes) {
    auto& node_state = GetNodeStateOrCreateIt(node);
    // TODO(dyoon): add SendRecv considering devices and control dependency.
    for (const string& input : node->input()) {
      const NodeDef* in = name_to_node[NodeName(input)];
      CHECK(in);
      node_state.inputs.push_back(in);
      auto& input_node_state = GetNodeStateOrCreateIt(in);
      input_node_state.outputs.push_back(node);
    }
    if (node->input().empty()) {
      node_state.time_ready =
          Costs::Duration();  // Node without input: ready at time 0.
      ready_nodes_->AddNode(node);
    }
  }
}

const NodeDef* VirtualScheduler::GetCurrNode() const {
  return ready_nodes_->GetCurrNode();
}

NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
  auto it = node_map_.find(node);
  if (it == node_map_.end()) {
    it = node_map_.emplace(node, NodeState()).first;
  }
  return it->second;
}

bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
  // Update graph_costs_ and per-op costs.
  graph_costs_ = CombineCosts(graph_costs_, node_costs);
  const auto* node = GetCurrNode();
  const auto& op_name = node->op();

  auto it = op_to_cost_.find(op_name);
  if (it == op_to_cost_.end()) {
    it = op_to_cost_.emplace(op_name, Costs::ZeroCosts()).first;
  }
  auto& op_cost = it->second;
  op_cost = CombineCosts(op_cost, node_costs);

  // Update node and device states.
  auto& node_state = node_map_[node];
  auto& device = device_[node->device()];
  device.nodes_executed.push_back(node);
  // Node is scheduled when the device is available AND all the inputs are
  // ready; hence, time_scheduled is time_ready if time_ready > device curr
  // time.
  node_state.time_scheduled =
      std::max(device.GetCurrTime(), node_state.time_ready);
  // Override device curr time with the time_scheduled.
  device.device_costs.execution_time = node_state.time_scheduled;
  device.device_costs = CombineCosts(device.device_costs, node_costs);
  auto curr_time = device.GetCurrTime();
  node_state.time_finished = curr_time;

  // Update device's per-op cost.
  {
    auto it = device.op_to_cost.find(op_name);
    if (it == device.op_to_cost.end()) {
      it = device.op_to_cost.emplace(op_name, Costs::ZeroCosts()).first;
    }
    auto& op_cost = it->second;
    op_cost = CombineCosts(op_cost, node_costs);

    VLOG(2) << "Op scheduled -- name: " << node->name()
            << ", op: " << node->op() << ", device: " << node->device()
            << ", ready: " << node_state.time_ready.count()
            << ", scheduled: " << node_state.time_scheduled.count()
            << ", finished: " << node_state.time_finished.count();

    // Increment num_inputs_ready of the output nodes.
    for (auto* output : node_state.outputs) {
      auto& output_state = node_map_[output];
      output_state.num_inputs_ready++;
      if (output_state.num_inputs_ready == output_state.inputs.size()) {
        // This output node is now ready.
        output_state.time_ready = curr_time;
        ready_nodes_->AddNode(output);
      }
    }

    // Increment num_outputs_executed of the input nodes.
    for (auto* input : node_state.inputs) {
      auto& input_state = node_map_[input];
      input_state.num_outputs_executed++;
      if (input_state.num_outputs_executed == input_state.outputs.size()) {
        // All the outputs are executed; no reference to this input node.
        input_state.time_no_reference = curr_time;
        // TODO(dyoon): collect device memory usage; note that this input node
        // uses device memory between time_scheduled and time_no_reference.
      }
    }
  }

  // Remove the current node; assume FIFO.
  ready_nodes_->RemoveCurrNode();
  return !ready_nodes_->Empty();  // True if not empty.
}

Costs VirtualScheduler::Summary() const {
  // Print out basic execution summary.
  VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
  VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
  VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
  VLOG(1) << "Expected max per-op streaming buffers: "
          << graph_costs_.max_per_op_streaming;

  VLOG(1) << "Per-op execution time:";
  for (const auto& op_cost_pair : op_to_cost_) {
    const auto& op = op_cost_pair.first;
    const auto& cost = op_cost_pair.second.execution_time.count();
    if (cost) {  // Skip printing out zero-cost ops.
      VLOG(1) << " + " << op << " : " << cost;
    }
  }

  // Print per-device summary.
  VLOG(1) << "Devices:";
  Costs critical_path_costs = Costs::ZeroCosts();

  for (const auto& device : device_) {
    const auto& name = device.first;
    const auto& state = device.second;
    VLOG(1) << "Device = " << name
            << ", num_nodes = " << state.nodes_executed.size()
            << ", execution_time = " << state.GetCurrTime().count();
    VLOG(1) << "Per-op execution time:";
    for (const auto& op_cost_pair : state.op_to_cost) {
      const auto& op = op_cost_pair.first;
      const auto& cost = op_cost_pair.second.execution_time.count();
      if (cost) {  // Skip printing out zero-cost ops.
        VLOG(1) << " + " << op << " : " << cost;
      }
    }
    if (critical_path_costs.execution_time <= state.GetCurrTime()) {
      critical_path_costs = state.device_costs;
    }
  }

  VLOG(1) << "Critical path execution time: "
          << critical_path_costs.execution_time.count();
  return critical_path_costs;
}

}  // end namespace grappler
}  // end namespace tensorflow
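The core scheduling rule in MarkCurrNodeExecuted is that a node starts at the later of (a) the moment its device becomes free and (b) the moment its last input became ready, after which the device clock advances by the node's cost. A minimal sketch of just that rule, with hypothetical types and durations as plain integers:

#include <algorithm>
#include <cstdint>

struct DeviceClock {
  int64_t curr_time = 0;
};

// Schedules one node on 'device' and returns its finish time, which in turn
// becomes the ready time of the node's outputs.
int64_t ScheduleNode(DeviceClock* device, int64_t time_ready,
                     int64_t duration) {
  const int64_t time_scheduled = std::max(device->curr_time, time_ready);
  device->curr_time = time_scheduled + duration;
  return device->curr_time;
}

Each finish time then feeds the time_ready of downstream nodes, exactly as the loop above pushes newly ready outputs into the FIFO manager.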
116
tensorflow/core/grappler/costs/virtual_scheduler.h
Normal file
116
tensorflow/core/grappler/costs/virtual_scheduler.h
Normal file
@ -0,0 +1,116 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_

#include <list>
#include <memory>
#include <unordered_map>

#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/grappler_item.h"

namespace tensorflow {
namespace grappler {

namespace {
struct NodeState {
  std::vector<const NodeDef*> inputs;
  std::vector<const NodeDef*> outputs;
  int num_inputs_ready;
  int num_outputs_executed;
  Costs::Duration time_ready;
  Costs::Duration time_scheduled;
  Costs::Duration time_finished;
  Costs::Duration time_no_reference;

  // Node will be ready to be executed at time_ready, scheduled at
  // time_scheduled, and finishes execution at time_finished.
  // Between time_scheduled and time_no_reference, the node's output tensor
  // needs to be on the device, using up device memory.

  NodeState() {
    num_inputs_ready = 0;
    num_outputs_executed = 0;
    time_ready = Costs::Duration::max();
    time_scheduled = Costs::Duration::max();
    time_finished = Costs::Duration::max();
    time_no_reference = Costs::Duration::max();
  }
};

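The comment above defines each output tensor's residency window on the device. As a hedged, self-contained illustration (not code from this commit), peak device memory under that lifetime model is just an interval-sweep maximum:

```cpp
// Hedged sketch: if each output tensor occupies device memory over the
// half-open window [time_scheduled, time_no_reference), peak memory is the
// maximum concurrent sum over a sweep of interval endpoints.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

struct Interval { int64_t start, end, bytes; };  // [start, end) residency.

int64_t PeakMemory(const std::vector<Interval>& intervals) {
  std::vector<std::pair<int64_t, int64_t>> events;  // (time, +/- bytes)
  for (const auto& iv : intervals) {
    events.push_back({iv.start, iv.bytes});
    events.push_back({iv.end, -iv.bytes});
  }
  // Lexicographic sort puts frees before allocations at equal timestamps,
  // which matches half-open intervals.
  std::sort(events.begin(), events.end());
  int64_t cur = 0, peak = 0;
  for (const auto& e : events) {
    cur += e.second;
    peak = std::max(peak, cur);
  }
  return peak;
}
```
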
struct DeviceState {
  std::vector<const NodeDef*> nodes_executed;
  Costs device_costs;
  std::map<string, Costs> op_to_cost;  // Per-op cost.

  DeviceState() { device_costs = Costs::ZeroCosts(); }

  Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
};

// ReadyNodeManager (abstract class):
// Keeps ready nodes and picks the best one to be scheduled.
class ReadyNodeManager {
 public:
  ReadyNodeManager() {}
  virtual ~ReadyNodeManager() {}
  virtual void AddNode(const NodeDef* node) = 0;
  virtual const NodeDef* GetCurrNode() const = 0;
  virtual void RemoveCurrNode() = 0;
  virtual bool Empty() const = 0;
};

class FIFOManager : public ReadyNodeManager {
 public:
  FIFOManager() : ReadyNodeManager() {}
  ~FIFOManager() override {}
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() const override { return nodes_.front(); }
  void RemoveCurrNode() override { nodes_.pop_front(); }
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
};
}  // namespace

// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
 public:
  VirtualScheduler(const GraphDef& graph,
                   const std::vector<string>& fetch_nodes);

  const NodeDef* GetCurrNode() const;
  bool MarkCurrNodeExecuted(const Costs& node_costs);

  Costs Summary() const;

 private:
  NodeState& GetNodeStateOrCreateIt(const NodeDef* node);

  Costs graph_costs_;                   // Graph cost.
  std::map<string, Costs> op_to_cost_;  // Per-op cost.
  std::unique_ptr<ReadyNodeManager> ready_nodes_;
  std::unordered_map<const NodeDef*, NodeState> node_map_;
  std::unordered_map<string, DeviceState> device_;
};

}  // namespace grappler
}  // end namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_

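ReadyNodeManager is the extension point for scheduling policy, and FIFOManager above is the only policy this commit ships. As a hedged sketch of how another policy would plug in (hypothetical, with `NodeDef` forward-declared as a stand-in so the snippet compiles on its own):

```cpp
#include <list>

struct NodeDef;  // Stand-in for tensorflow::NodeDef.

// Mirror of the ReadyNodeManager interface from the header above.
class ReadyNodeManager {
 public:
  virtual ~ReadyNodeManager() {}
  virtual void AddNode(const NodeDef* node) = 0;
  virtual const NodeDef* GetCurrNode() const = 0;
  virtual void RemoveCurrNode() = 0;
  virtual bool Empty() const = 0;
};

// Hypothetical LIFO policy: schedules the most recently readied node first,
// which tends to drain the graph depth-first instead of breadth-first.
class LIFOManager : public ReadyNodeManager {
 public:
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() const override { return nodes_.back(); }
  void RemoveCurrNode() override { nodes_.pop_back(); }
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
};
```
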
@ -2109,7 +2109,9 @@ tf_kernel_library(

tf_kernel_library(
    name = "matrix_triangular_solve_op",
    prefix = "matrix_triangular_solve_op",
    deps = LINALG_DEPS,
    deps = LINALG_DEPS + if_cuda([
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
    ]),
)

tf_kernel_library(
@ -2350,6 +2352,8 @@ tf_kernel_library(
        "//conditions:default": [],
    }) + if_mkl([
        "//third_party/mkl:intel_binary_blob",
    ]) + if_cuda([
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
    ]),
)

@ -2630,6 +2634,7 @@ tf_kernel_library(
        ],
        "//conditions:default": [],
    }) + if_cuda([
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
        "//tensorflow/core/platform/default/build_config:cudnn_plugin",
    ]),
)

@ -24,28 +24,32 @@ limitations under the License.

#if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll")
#define NOUNROLL _Pragma("nounroll")
#else
#define UNROLL
#define NOUNROLL
#endif

namespace tensorflow {

namespace {

typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;

// A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
                                             const T* input, const T* filter,
                                             T* output, int num_outputs) {
  const int in_rows = args.in_rows;
  const int in_cols = args.in_cols;
  const int in_depth = args.in_depth;
  const int filter_rows = args.filter_rows;
  const int filter_cols = args.filter_cols;
  const int depth_multiplier = args.depth_multiplier;
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  const int stride = args.stride;
  const int pad_rows = args.pad_rows;
  const int pad_cols = args.pad_cols;
@ -114,16 +118,20 @@ __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,

// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
                                             const T* input, const T* filter,
                                             T* output, int num_outputs) {
  const int in_rows = args.in_rows;
  const int in_cols = args.in_cols;
  const int in_depth = args.in_depth;
  const int filter_rows = args.filter_rows;
  const int filter_cols = args.filter_cols;
  const int depth_multiplier = args.depth_multiplier;
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  const int stride = args.stride;
  const int pad_rows = args.pad_rows;
  const int pad_cols = args.pad_cols;
@ -235,49 +243,63 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
  }
}

}  // namespace

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dGPULaunch {
  static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input,
                  const T* filter, T* output, TensorFormat data_format) {
// In this kernel, each thread is computing the gradients from one element
// in the out_backprop. Note that one element in the out_backprop can map
// to multiple filter elements.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
                              const T* input, const T* filter, T* output,
                              TensorFormat data_format) {
  const int num_outputs =
      args.batch * args.out_rows * args.out_cols * args.out_depth;
  CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
  if (data_format == FORMAT_NHWC) {
    DepthwiseConv2dGPUKernelNHWC<T>
    DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
                                 kKnownDepthMultiplier>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            args, input, filter, output, num_outputs);
  } else if (data_format == FORMAT_NCHW) {
    DepthwiseConv2dGPUKernelNCHW<T>
    DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
                                 kKnownDepthMultiplier>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            args, input, filter, output, num_outputs);
  } else {
    assert(false);
  }
}

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dGPULaunch {
  static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
                  const T* filter, T* output, TensorFormat data_format) {
    if (args.filter_rows == 3 && args.filter_cols == 3 &&
        args.depth_multiplier == 1) {
      LaunchDepthwiseConv2dGPU<T, 3, 3, 1>(d, args, input, filter, output,
                                           data_format);
    } else {
      LaunchDepthwiseConv2dGPU<T, -1, -1, -1>(d, args, input, filter, output,
                                              data_format);
    }
  }
};

template struct DepthwiseConv2dGPULaunch<float>;
template struct DepthwiseConv2dGPULaunch<double>;

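The launch pads above pick a kernel instantiation at runtime: a template argument of -1 means "read the value from `args`", while a non-negative argument becomes a compile-time constant that the compiler can fold and fully unroll loops against. A minimal host-side C++ sketch of the same dispatch pattern, with hypothetical names and no CUDA dependency:

```cpp
#include <cstdio>

struct Args { int filter_rows, filter_cols, depth_multiplier; };

// -1 means "unknown until runtime"; a non-negative value is a compile-time
// constant, so the ternaries below fold away in the specialized instantiation.
template <int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
void Kernel(const Args& args) {
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  std::printf("%d x %d, multiplier %d\n", filter_rows, filter_cols,
              depth_multiplier);
}

void Run(const Args& args) {
  if (args.filter_rows == 3 && args.filter_cols == 3 &&
      args.depth_multiplier == 1) {
    Kernel<3, 3, 1>(args);    // Common case: all shape constants known.
  } else {
    Kernel<-1, -1, -1>(args);  // Generic fallback: everything from args.
  }
}
```
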
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
template <typename T, int KNOWN_DEPTH_MULTIPLIER>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
    const DepthwiseArgs args, const T* out_backprop, const T* filter,
    T* in_backprop, int num_in_backprop) {
  const int in_rows = args.in_rows;
  const int in_cols = args.in_cols;
  const int in_depth = args.in_depth;
  const int filter_rows = args.filter_rows;
  const int filter_cols = args.filter_cols;
  const int depth_multiplier = KNOWN_DEPTH_MULTIPLIER == -1
                                   ? args.depth_multiplier
                                   : KNOWN_DEPTH_MULTIPLIER;
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  const int stride = args.stride;
  const int pad_rows = args.pad_rows;
  const int pad_cols = args.pad_cols;
@ -301,14 +323,12 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
        tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride);
    const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride);

#pragma nounroll
    for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
    NOUNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
      const int f_r = in_r + pad_rows - out_r * stride;
      const int temp_out_backprop_offset =
          out_depth * out_cols * (out_r + out_rows * b);
      const int temp_filter_offset = filter_cols * f_r;
#pragma nounroll
      for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
      NOUNROLL for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
        const int f_c = in_c + pad_cols - out_c * stride;
        int filter_offset =
            depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset));
@ -328,7 +348,8 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
  }
}

template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
__global__ void __launch_bounds__(1024)
    DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
                                              const T* out_backprop,
@ -337,9 +358,12 @@ __global__ void __launch_bounds__(1024)
  const int in_rows = args.in_rows;
  const int in_cols = args.in_cols;
  const int in_depth = args.in_depth;
  const int filter_rows = args.filter_rows;
  const int filter_cols = args.filter_cols;
  const int depth_multiplier = args.depth_multiplier;
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  const int stride = args.stride;
  const int pad_rows = args.pad_rows;
  const int pad_cols = args.pad_cols;
@ -395,52 +419,74 @@ __global__ void __launch_bounds__(1024)
  }
}

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch {
  static void Run(const GPUDevice& d, const DepthwiseArgs args,
                  const T* out_backprop, const T* filter, T* in_backprop,
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
                                           const DepthwiseArgs args,
                                           const T* out_backprop,
                                           const T* filter, T* in_backprop,
                                           TensorFormat data_format) {
  const int num_in_backprop =
      args.batch * args.in_rows * args.in_cols * args.in_depth;

  CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d);
  // Increase block count for when there are more warps/SM than threads/SM.
  // TODO(csigg): this is pretty arbitrary and should be generalized using
  // cudaOccupancyMaxPotentialBlockSize().
  config.block_count *= 4;
  if (data_format == FORMAT_NHWC) {
    if (args.depth_multiplier == 1) {
      DepthwiseConv2dBackpropInputGPUKernelNHWC<T, 1>
    DepthwiseConv2dBackpropInputGPUKernelNHWC<
        T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            args, out_backprop, filter, in_backprop, num_in_backprop);
    } else {
      DepthwiseConv2dBackpropInputGPUKernelNHWC<T, -1>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              args, out_backprop, filter, in_backprop, num_in_backprop);
    }
  } else if (data_format == FORMAT_NCHW) {
    DepthwiseConv2dBackpropInputGPUKernelNCHW<T>
    DepthwiseConv2dBackpropInputGPUKernelNCHW<
        T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            args, out_backprop, filter, in_backprop, num_in_backprop);
  } else {
    assert(false);
  }
}

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch {
  static void Run(const GpuDevice& d, const DepthwiseArgs args,
                  const T* out_backprop, const T* filter, T* in_backprop,
                  TensorFormat data_format) {
    if (args.depth_multiplier == 1) {
      if (args.filter_rows == 3 && args.filter_cols == 3) {
        LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3, 1>(
            d, args, out_backprop, filter, in_backprop, data_format);
      } else {
        LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, 1>(
            d, args, out_backprop, filter, in_backprop, data_format);
      }
    } else {
      LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, -1>(
          d, args, out_backprop, filter, in_backprop, data_format);
    }
  }
};

template struct DepthwiseConv2dBackpropInputGPULaunch<float>;
template struct DepthwiseConv2dBackpropInputGPULaunch<double>;

// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
    const DepthwiseArgs args, const T* out_backprop, const T* input,
    T* filter_backprop, int num_out_backprop) {
  const int in_rows = args.in_rows;
  const int in_cols = args.in_cols;
  const int in_depth = args.in_depth;
  const int filter_rows = args.filter_rows;
  const int filter_cols = args.filter_cols;
  const int depth_multiplier = args.depth_multiplier;
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  const int stride = args.stride;
  const int pad_rows = args.pad_rows;
  const int pad_cols = args.pad_cols;
@ -518,16 +564,20 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
}

// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
    const DepthwiseArgs args, const T* out_backprop, const T* input,
    T* filter_backprop, int num_out_backprop) {
  const int in_rows = args.in_rows;
  const int in_cols = args.in_cols;
  const int in_depth = args.in_depth;
  const int filter_rows = args.filter_rows;
  const int filter_cols = args.filter_cols;
  const int depth_multiplier = args.depth_multiplier;
  const int filter_rows =
      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
  const int filter_cols =
      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
  const int depth_multiplier =
      kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
  const int stride = args.stride;
  const int pad_rows = args.pad_rows;
  const int pad_cols = args.pad_cols;
@ -610,30 +660,46 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
  }
}

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch {
  static void Run(const GPUDevice& d, const DepthwiseArgs args,
                  const T* out_backprop, const T* input, T* filter_backprop,
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
          int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
                                            const DepthwiseArgs args,
                                            const T* out_backprop,
                                            const T* input, T* filter_backprop,
                                            TensorFormat data_format) {
  // In this kernel, each thread is computing the gradients for one element in
  // the out_backprop.
  const int num_out_backprop =
      args.batch * args.out_rows * args.out_cols * args.out_depth;
  CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);

  if (data_format == FORMAT_NHWC) {
    DepthwiseConv2dBackpropFilterGPUKernelNHWC<T>
    DepthwiseConv2dBackpropFilterGPUKernelNHWC<
        T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            args, out_backprop, input, filter_backprop, num_out_backprop);
  } else if (data_format == FORMAT_NCHW) {
    DepthwiseConv2dBackpropFilterGPUKernelNCHW<T>
    DepthwiseConv2dBackpropFilterGPUKernelNCHW<
        T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            args, out_backprop, input, filter_backprop, num_out_backprop);
  } else {
    assert(false);
  }
}

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch {
  static void Run(const GpuDevice& d, const DepthwiseArgs args,
                  const T* out_backprop, const T* input, T* filter_backprop,
                  TensorFormat data_format) {
    if (args.filter_rows == 3 && args.filter_cols == 3 &&
        args.depth_multiplier == 1) {
      LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3, 1>(
          d, args, out_backprop, input, filter_backprop, data_format);
    } else {
      LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1, -1>(
          d, args, out_backprop, input, filter_backprop, data_format);
    }
  }
};

template struct DepthwiseConv2dBackpropFilterGPULaunch<float>;

@ -220,6 +220,10 @@ And the following command line executes the `HelloTF` program on Windows:

<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>

And the following command line executes the `HelloTF` program on Windows:

<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>

If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
outputs something else, check

@ -3,9 +3,9 @@

## Overview

A selection of image classification models were tested across multiple platforms
to create a point of reference for the TensorFlow community. The methodology,
links to the benchmark scripts, and commands to reproduce the results are in the
[Appendix](#appendix).
to create a point of reference for the TensorFlow community. The
[Methodology](#methodology) section details how the tests were executed and has
links to the scripts used.

## Results for image classification models

@ -120,19 +120,19 @@ VGG16 | replicated (with NCCL) | n/a

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 238 | 95.6 | 2987 | 132
2 | 284 | 479 | 187 | 5658 | 259
4 | 569 | 948 | 374 | 10509 | 511
8 | 1131 | 1886 | 744 | 17822 | 959
1 | 142 | 238 | 95.6 | 2987 | 154
2 | 284 | 479 | 187 | 5658 | 295
4 | 569 | 948 | 374 | 10509 | 584
8 | 1131 | 1886 | 744 | 17822 | 1081

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 239 | 95.5 | 2890 | 132
2 | 278 | 468 | 187 | 4448 | 245
4 | 551 | 938 | 373 | 7105 | 466
8 | 1079 | 1802 | 721 | N/A | 794
1 | 142 | 239 | 95.5 | 2890 | 154
2 | 278 | 468 | 187 | 4448 | 284
4 | 551 | 938 | 373 | 7105 | 534
8 | 1079 | 1802 | 721 | N/A | 898

Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to it maxing out the input pipeline.
@ -145,19 +145,19 @@ The results below are all with a batch size of 32.

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
1 | 128 | 210 | 85.3 | 124
2 | 259 | 412 | 166 | 241
4 | 520 | 827 | 330 | 470
8 | 995 | 1623 | 643 | 738
1 | 128 | 210 | 85.3 | 144
2 | 259 | 412 | 166 | 281
4 | 520 | 827 | 330 | 549
8 | 995 | 1623 | 643 | 820

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
1 | 130 | 208 | 85.0 | 124
2 | 257 | 403 | 163 | 221
4 | 507 | 814 | 325 | 401
8 | 966 | 1525 | 641 | 619
1 | 130 | 208 | 85.0 | 144
2 | 257 | 403 | 163 | 253
4 | 507 | 814 | 325 | 457
8 | 966 | 1525 | 641 | 690

## Details for Google Compute Engine (NVIDIA® Tesla® K80)

@ -198,19 +198,19 @@ The configuration used for each model was `variable_update` equal to

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.8 | 20.8 | 656 | 30.3
2 | 57.8 | 107 | 39.1 | 1210 | 56.2
4 | 116 | 212 | 77.2 | 2330 | 106
8 | 227 | 419 | 151 | 4640 | 222
1 | 30.5 | 56.8 | 20.8 | 656 | 35.4
2 | 57.8 | 107 | 39.1 | 1209 | 64.8
4 | 116 | 212 | 77.2 | 2328 | 120
8 | 227 | 419 | 151 | 4640 | 234

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.6 | 56.7 | 20.7 | 639 | 30.2
2 | 58.4 | 107 | 39.0 | 1136 | 55.5
4 | 115 | 211 | 77.3 | 2067 | 106
8 | 225 | 422 | 151 | 4056 | 213
1 | 30.6 | 56.7 | 20.7 | 639 | 34.2
2 | 58.4 | 107 | 39.0 | 1136 | 62.9
4 | 115 | 211 | 77.3 | 2067 | 118
8 | 225 | 422 | 151 | 4056 | 230

### Other Results

@ -279,19 +279,19 @@ VGG16 | parameter_server | gpu

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.8 | 56.3 | 20.9 | 684 | 32.4
2 | 58.7 | 108 | 39.3 | 1244 | 61.5
4 | 117 | 217 | 79.1 | 2479 | 123
8 | 230 | 419 | 156 | 4853 | 234
1 | 30.8 | 56.3 | 20.9 | 684 | 36.3
2 | 58.7 | 108 | 39.3 | 1244 | 69.4
4 | 117 | 217 | 79.1 | 2479 | 141
8 | 230 | 419 | 156 | 4853 | 260

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.0 | 20.6 | 674 | 32.0
2 | 58.7 | 107 | 39.0 | 1227 | 61.0
4 | 118 | 205 | 77.9 | 2201 | 120
8 | 228 | 405 | 152 | N/A | 191
1 | 30.5 | 56.0 | 20.6 | 674 | 36.3
2 | 59.0 | 107 | 39.0 | 1227 | 67.5
4 | 118 | 205 | 77.9 | 2201 | 136
8 | 228 | 405 | 152 | N/A | 242

Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to our EFS setup not providing enough throughput.

@ -393,63 +393,17 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
32 | 820 | 1265
64 | 1608 | 2623

## Appendix

### Executing benchmark tests
## Methodology

The [benchmark code](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was created to be used for benchmarking TensorFlow as well as used as a tool to
test hardware platforms. Techniques used in the benchmark scripts are detailed
in @{$performance_models$High-Performance Models}.
This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was run on the various platforms to generate the above results.
@{$performance_models$High-Performance Models} details techniques in the script
along with examples of how to execute the script.

There are two ways to execute the benchmark code:

1. Execute [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
   directly.
2. Utilize the [scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/main.py)
   that help pick the correct config for each platform and execute
   `tf_cnn_benchmarks.py`.

The wrapper is suggested as a starting point. Then investigate the variety of
options available in `tf_cnn_benchmarks.py`. Below are a couple examples of
using the wrapper.

**Single Server**
This example illustrates training ResNet-50 on a single instance with 8 GPUs.
The `system` flag is used to determine the optimal configuration. The
supported values are gce, aws, and dgx1. If `system` is not passed, the best
config for the most widely available hardware is used.

```bash
python main.py --model=resnet50 --num_gpus=8
python main.py --system=aws --model=resnet50 --num_gpus=8
```

**Distributed**
This example illustrates training ResNet-50 on 2 hosts, e.g. host_0 (10.0.0.1)
and host_1 (10.0.0.2), with 8 GPUs each on AWS (Amazon EC2).

```bash
# Run the following commands on host_0 (10.0.0.1):
$ python main.py --system=aws --model=resnet50 --job_name=worker \
    --hosts=10.0.0.1,10.0.0.2 --task_index=0

$ python main.py --system=aws --model=resnet50 --job_name=ps \
    --hosts=10.0.0.1,10.0.0.2 --task_index=0

# Run the following commands on host_1 (10.0.0.2):
$ python main.py --system=aws --model=resnet50 --job_name=worker \
    --hosts=10.0.0.1,10.0.0.2 --task_index=1

$ python main.py --system=aws --model=resnet50 --job_name=ps \
    --hosts=10.0.0.1,10.0.0.2 --task_index=1
```

### Methodology

Unless otherwise stated, each test is run 5 times and then the times are
averaged together. GPUs are run in their default state on the given platform.
For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/)
unless it has been turned off by the provider. For a given test, 10 warmup steps
are done and then the next 100 steps are averaged.
In order to create results that are as repeatable as possible, each test was run
5 times and then the times were averaged together. GPUs are run in their default
state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/).
For each test, 10 warmup steps are done and then the next 100 steps are
averaged.

@ -9,7 +9,7 @@ deeper with techniques detailed in @{$performance_models$High-Performance Models
practices for optimizing your TensorFlow code.

* @{$performance_models$High-Performance Models}, which contains a collection
  advanced techniques to build highly scalable models targeting different
  of advanced techniques to build highly scalable models targeting different
  system types and network topologies.

* @{$benchmarks$Benchmarks}, which contains a collection of benchmark

@ -14,8 +14,8 @@ input pipeline issues and best practices. We found that using @{tf.FIFOQueue}
and @{tf.train.queue_runner} could not saturate multiple current generation GPUs
when using large inputs and processing with higher samples per second, such
as training ImageNet with [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf).
This is due to the the use of Python threads as its underlying implementation.
The overhead of Python threads is too large.
This is due to the use of Python threads as its underlying implementation. The
overhead of Python threads is too large.

Another approach, which we have implemented in the
[scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks),
@ -327,3 +327,96 @@ free.

The downside is that all the weights read are from the previous training step.
So it is a different algorithm from SGD. But it is possible to improve its
convergence by adjusting the learning rate and other hyperparameters.

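Spelled out (our reading of the paragraph above, not a formula from the original doc), the stale read replaces the gradient at the current iterate with the gradient at the previous one:

```latex
w_{t+1} = w_t - \eta \,\nabla L(w_t)      % standard SGD
w_{t+1} = w_t - \eta \,\nabla L(w_{t-1})  % weights read one step late
```
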
## Executing the script

This section lists the core command line arguments and a few basic examples for
executing the main script
([tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)).

> Note: `tf_cnn_benchmarks.py` uses the config `force_gpu_compatible`,
> which was introduced after TensorFlow 1.1. Until TensorFlow 1.2 is released,
> building from source is advised.

#### Base command line arguments

* **`model`**: Model to use, e.g. `resnet50`, `inception3`, `vgg16`, and
  `alexnet`.
* **`num_gpus`**: Number of GPUs to use.
* **`data_dir`**: Path to data to process. If not set, synthetic data is used.
  To use ImageNet data use these
  [instructions](https://github.com/tensorflow/models/tree/master/inception#getting-started)
  as a starting point.
* **`batch_size`**: Batch size for each GPU.
* **`variable_update`**: The method for managing variables: `parameter_server`,
  `replicated`, `distributed_replicated`, `independent`.
* **`local_parameter_device`**: Device to use as parameter server: `cpu` or
  `gpu`.

#### Single instance examples

```bash
# VGG16 training ImageNet with 8 GPUs using arguments that optimize for
# Google Compute Engine.
python tf_cnn_benchmarks.py --local_parameter_device=cpu --num_gpus=8 \
--batch_size=32 --model=vgg16 --data_dir=/home/ubuntu/imagenet/train \
--variable_update=parameter_server --nodistortions

# VGG16 training synthetic ImageNet data with 8 GPUs using arguments that
# optimize for the NVIDIA DGX-1.
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=vgg16 --variable_update=replicated --use_nccl=True

# VGG16 training ImageNet data with 8 GPUs using arguments that optimize for
# Amazon EC2.
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=vgg16 --variable_update=parameter_server

# ResNet-50 training ImageNet data with 8 GPUs using arguments that optimize for
# Amazon EC2.
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=replicated --use_nccl=False
```

#### Distributed command line arguments

* **`ps_hosts`**: Comma separated list of hosts to use as parameter servers
  in the format of `<host>:port`, e.g. `10.0.0.2:50000`.
* **`worker_hosts`**: Comma separated list of hosts to use as workers in the
  format of `<host>:port`, e.g. `10.0.0.2:50001`.
* **`task_index`**: Index of the host in the list of `ps_hosts` or
  `worker_hosts` being started.
* **`job_name`**: Type of job, e.g. `ps` or `worker`.

#### Distributed examples

Below is an example of training ResNet-50 on 2 hosts: host_0 (10.0.0.1) and
host_1 (10.0.0.2). The example uses synthetic data. To use real data pass the
`--data_dir` argument.

```bash
# Run the following commands on host_0 (10.0.0.1):
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=worker --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=0

python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=ps --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=0

# Run the following commands on host_1 (10.0.0.2):
python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=worker --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=1

python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
--batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
--job_name=ps --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
--worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=1
```

@ -5,7 +5,7 @@ in the way described in the @{$variables$Variables HowTo}.
But when building complex models you often need to share large sets of
variables and you might want to initialize all of them in one place.
This tutorial shows how this can be done using `tf.variable_scope()` and
the `tf.get_variable()`.
`tf.get_variable()`.

## The Problem

@ -368,6 +368,6 @@ sequence-to-sequence models.

File | What's in it?
--- | ---
`models/tutorials/image/cifar10/cifar10.py` | Model for detecting objects in images.
`models/tutorials/rnn/rnn_cell.py` | Cell functions for recurrent neural networks.
`models/tutorials/rnn/seq2seq.py` | Functions for building sequence-to-sequence models.
`tutorials/image/cifar10/cifar10.py` | Model for detecting objects in images.
`tutorials/rnn/rnn_cell.py` | Cell functions for recurrent neural networks.
`tutorials/rnn/seq2seq.py` | Functions for building sequence-to-sequence models.

@ -83,7 +83,7 @@ for details. It consists of 1,068,298 learnable parameters and requires about

## Code Organization

The code for this tutorial resides in
[`tensorflow_models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/).
[`models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/).

File | Purpose
--- | ---

@ -348,12 +348,6 @@ class BaseDebugWrapperSession(session.SessionInterface):

    _check_type(sess, session.BaseSession)

    # TODO(cais): Remove this check once tfdbg is integrated with GrpcSession.
    if sess.sess_str:
      raise NotImplementedError(
          "Non-DirectSession support is not available from TensorFlow "
          "Debugger yet (sess_str=%s)" % sess.sess_str)

    # The session being wrapped.
    self._sess = sess
    self._thread_name_filter_pattern = (re.compile(thread_name_filter)

@ -384,18 +384,6 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
        ["a_init", "b_init"],
        [datum.node_name for datum in dump.dumped_tensor_data])

  def testUsingNonDirectSessionRaisesNotImplementedError(self):
    # TODO(cais): Remove this test once tfdbg is integrated with GrpcSession.
    fake_non_direct_session = session.Session()
    fake_non_direct_session._target = "foo"

    with self.assertRaisesRegexp(
        NotImplementedError,
        r"Non-DirectSession support is not available from TensorFlow Debugger "
        r"yet \(sess_str=foo\)"):
      TestDebugWrapperSession(
          fake_non_direct_session, self._dump_root, self._observer)


if __name__ == "__main__":
  googletest.main()

@ -139,6 +139,82 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def make_input_layer(features,
                     feature_columns,
                     weight_collections=None,
                     trainable=True):
  """Returns a dense `Tensor` as input layer based on given `feature_columns`.

  Generally a single example in training data is described with FeatureColumns.
  At the first layer of the model, this column oriented data should be converted
  to a single `Tensor`.

  Example:

  ```python
  price = numeric_column('price')
  keywords_embedded = embedding_column(
      categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
  all_feature_columns = [price, keywords_embedded, ...]
  dense_tensor = make_input_layer(features, all_feature_columns)
  for units in [128, 64, 32]:
    dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
  prediction = tf.layers.dense(dense_tensor, 1)
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor` depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s. All items
      should be instances of classes derived from `_DenseColumn` such as
      `numeric_column`, `embedding_column`, `bucketized_column`,
      `indicator_column`. If you have categorical features, you can wrap them
      with an `embedding_column` or `indicator_column`.
    weight_collections: A list of collection names to which the Variable will be
      added. Note that variables will also be added to collections
      `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).

  Returns:
    A `Tensor` which represents input layer of a model. Its shape
    is (batch_size, first_layer_dimension) and its dtype is `float32`.
    first_layer_dimension is determined based on given `feature_columns`.

  Raises:
    ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
  """
  _check_feature_columns(feature_columns)
  for column in feature_columns:
    if not isinstance(column, _DenseColumn):
      raise ValueError(
          'Items of feature_columns must be a _DenseColumn. '
          'You can wrap a categorical column with an '
          'embedding_column or indicator_column. Given: {}'.format(column))
  weight_collections = list(weight_collections or [])
  if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
  if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
  with variable_scope.variable_scope(
      None, default_name='make_input_layer', values=features.values()):
    builder = _LazyBuilder(features)
    output_tensors = []
    for column in sorted(feature_columns, key=lambda x: x.name):
      with variable_scope.variable_scope(None, default_name=column.name):
        tensor = column._get_dense_tensor(  # pylint: disable=protected-access
            builder,
            weight_collections=weight_collections,
            trainable=trainable)
        num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
        batch_size = array_ops.shape(tensor)[0]
        tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
        output_tensors.append(tensor)
    return array_ops.concat(output_tensors, 1)


def make_linear_model(features,
                      feature_columns,
                      units=1,
@ -156,10 +232,21 @@ def make_linear_model(features,
  while `make_input_layer` explicitly requires wrapping each of them with an
  `embedding_column` or an `indicator_column`.

  Example:

  ```python
  price = numeric_column('price')
  price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
  keywords = categorical_column_with_hash_bucket("keywords", 10K)
  all_feature_columns = [price_buckets, keywords, ...]
  prediction = make_linear_model(features, all_feature_columns)
  ```

  Args:
    features: A mapping from key to tensors. 'string' key means a base feature.
      It can have `_FeatureColumn` as a key too. That means that FeatureColumn
      is already transformed by the input pipeline.
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values are `Tensor` or `SparseTensor` depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the FeatureColumns. All items
      should be instances of classes derived from FeatureColumn.
    units: An integer, dimensionality of the output space. Default
@ -191,22 +278,23 @@ def make_linear_model(features,
    raise ValueError('Items of feature_columns must be either a _DenseColumn '
                     'or _CategoricalColumn. Given: {}'.format(column))
  weight_collections = list(weight_collections or [])
  weight_collections += [
      ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
  ]
  if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
  if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
  with variable_scope.variable_scope(
      None, default_name='make_linear_model', values=features.values()):
    weigthed_sums = []
    builder = _LazyBuilder(features)
    for column in sorted(feature_columns, key=lambda x: x.name):
      with variable_scope.variable_scope(None, default_name=column.name):
        if isinstance(column, _DenseColumn):
          weigthed_sums.append(_create_dense_column_weighted_sum(
              column, builder, units, weight_collections, trainable))
        else:
        if isinstance(column, _CategoricalColumn):
          weigthed_sums.append(_create_categorical_column_weighted_sum(
              column, builder, units, sparse_combiner, weight_collections,
              trainable))
        else:
          weigthed_sums.append(_create_dense_column_weighted_sum(
              column, builder, units, weight_collections, trainable))
    predictions_no_bias = math_ops.add_n(
        weigthed_sums, name='weighted_sum_no_bias')
    bias = variable_scope.get_variable(

@ -228,7 +316,8 @@ def numeric_column(key,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  An example:
  Example:

  ```python
  price = numeric_column('price')
  all_feature_columns = [price, ...]
@ -237,7 +326,7 @@ def numeric_column(key,
  # or
  bucketized_price = bucketized_column(price, boundaries=[...])
  all_feature_columns = [bucketized_price, ...]
  linear_prediction, _, _ = make_linear_model(features, all_feature_columns)
  linear_prediction = make_linear_model(features, all_feature_columns)

  ```

@ -291,6 +380,56 @@ def numeric_column(key,
      normalizer_fn=normalizer_fn)


def bucketized_column(source_column, boundaries):
  """Represents discretized dense input.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  Example:

  ```python
  price = numeric_column('price')
  bucketized_price = bucketized_column(price, boundaries=[...])
  all_feature_columns = [bucketized_price, ...]
  linear_prediction = make_linear_model(features, all_feature_columns)

  # or
  all_feature_columns = [bucketized_price, ...]
  dense_tensor = make_input_layer(features, all_feature_columns)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `_BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, _NumericColumn):
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if (not boundaries or
      not (isinstance(boundaries, list) or isinstance(boundaries, tuple))):
    raise ValueError('boundaries must be a sorted list.')
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return _BucketizedColumn(source_column, tuple(boundaries))


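The boundary convention above (left edge included, right edge excluded) amounts to counting how many boundaries are less than or equal to the value. A hedged, self-contained C++ sketch of that rule, not TensorFlow's actual `_bucketize` op:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Bucket index = number of boundaries <= value, so left edges are included.
int Bucketize(float value, const std::vector<float>& boundaries) {
  return std::upper_bound(boundaries.begin(), boundaries.end(), value) -
         boundaries.begin();
}

int main() {
  const std::vector<float> b = {0.f, 1.f, 2.f};
  assert(Bucketize(-0.5f, b) == 0);  // (-inf, 0.)
  assert(Bucketize(0.f, b) == 1);    // [0., 1.) -- left boundary included.
  assert(Bucketize(1.5f, b) == 2);   // [1., 2.)
  assert(Bucketize(7.f, b) == 3);    // [2., +inf)
  return 0;
}
```
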
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
@ -300,11 +439,12 @@ def categorical_column_with_hash_bucket(key,
  want to distribute your inputs into a finite number of buckets by hashing.
  output_id = Hash(input_feature_string) % bucket_size

  An example:
  Example:

  ```python
  keywords = categorical_column_with_hash_bucket("keywords", 10K)
  linear_prediction, _, _ = make_linear_model(features, all_feature_columns)
  all_feature_columns = [keywords, ...]
  linear_prediction = make_linear_model(features, all_feature_columns)

  # or
  keywords_embedded = embedding_column(keywords, 16)
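The `output_id = Hash(input_feature_string) % bucket_size` rule quoted above can be sketched as follows; `std::hash` is only a stand-in, since the fingerprint function TensorFlow actually uses differs:

```cpp
#include <cstdint>
#include <functional>
#include <string>

// Hedged sketch of hash bucketing: collisions are expected and acceptable,
// since the model learns a shared weight per bucket, not per raw string.
uint64_t HashBucket(const std::string& feature, uint64_t hash_bucket_size) {
  return std::hash<std::string>{}(feature) % hash_bucket_size;
}
```
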
@ -422,7 +562,7 @@ class _DenseColumn(_FeatureColumn):

  @abc.abstractproperty
  def _variable_shape(self):
    """Returns shape of variable which is compatible with _get_dense_tensor."""
    """Returns a `TensorShape` of variable compatible with _get_dense_tensor."""
    pass

  @abc.abstractmethod
@ -431,6 +571,7 @@ class _DenseColumn(_FeatureColumn):

    The output of this function will be used by model-builder functions. For
    example the pseudo code of `make_input_layer` will be like:

    ```python
    def make_input_layer(features, feature_columns, ...):
      outputs = [fc._get_dense_tensor(...) for fc in feature_columns]
@ -454,7 +595,7 @@ def _create_dense_column_weighted_sum(
      builder,
      weight_collections=weight_collections,
      trainable=trainable)
  num_elements = tensor_shape.TensorShape(column._variable_shape).num_elements()  # pylint: disable=protected-access
  num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
  batch_size = array_ops.shape(tensor)[0]
  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
  weight = variable_scope.get_variable(
@ -566,12 +707,15 @@ class _LazyBuilder(object):
    """Creates a `_LazyBuilder`.

    Args:
      features: A mapping from feature column to tensors. A `string` key
      features: A mapping from feature column to objects that are `Tensor` or
        `SparseTensor`, or can be converted to same via
        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
        signifies a base feature (not-transformed). A `FeatureColumn` key
        means that this `Tensor` is the output of an existing `FeatureColumn`
        which can be reused.
    """
    self._columns_to_tensors = features.copy()
    self._features = features.copy()
    self._feature_tensors = {}

  def get(self, key):
    """Returns a `Tensor` for the given key.
@ -591,9 +735,16 @@ class _LazyBuilder(object):
      ValueError: if key is not found or a transformed `Tensor` cannot be
        computed.
    """
    if key in self._columns_to_tensors:
      # Feature_column is already transformed or it's a raw feature.
      return self._columns_to_tensors[key]
    if key in self._feature_tensors:
      # FeatureColumn is already transformed or converted.
      return self._feature_tensors[key]

    if key in self._features:
      # FeatureColumn is a raw feature.
      feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
          self._features[key])
      self._feature_tensors[key] = feature_tensor
      return feature_tensor

    if not isinstance(key, (str, _FeatureColumn)):
      raise TypeError('"key" must be either a "str" or "_FeatureColumn". '
@ -604,11 +755,13 @@ class _LazyBuilder(object):

    column = key
    logging.debug('Transforming feature_column %s.', column)
    transformed = column._transform_feature(self)  # pylint: disable=protected-access
    # pylint: disable=protected-access
    transformed = column._transform_feature(self)
    # pylint: enable=protected-access
    if transformed is None:
      raise ValueError('Column {} is not supported.'.format(column.name))
    self._columns_to_tensors[column] = transformed
    return self._columns_to_tensors[column]
    self._feature_tensors[column] = transformed
    return transformed

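In C++ terms, the `get()` method above implements transform-once memoization over a raw feature map. A hedged analogue with simplified stand-in types, not TensorFlow code:

```cpp
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

using Tensor = std::string;  // Stand-in for a real tensor type.

class LazyBuilder {
 public:
  explicit LazyBuilder(std::unordered_map<std::string, Tensor> features)
      : features_(std::move(features)) {}

  // Returns the value for `key`, running `transform` at most once per key.
  const Tensor& Get(const std::string& key,
                    const std::function<Tensor(LazyBuilder&)>& transform) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;  // Already transformed.
    auto raw = features_.find(key);
    if (raw != features_.end()) {               // Raw feature: cache as-is.
      return cache_.emplace(key, raw->second).first->second;
    }
    // Derived feature: transform once, then reuse from the cache.
    return cache_.emplace(key, transform(*this)).first->second;
  }

 private:
  std::unordered_map<std::string, Tensor> features_;
  std::unordered_map<std::string, Tensor> cache_;
};
```
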
def _check_feature_columns(feature_columns):
@ -660,7 +813,7 @@ class _NumericColumn(_DenseColumn,

  @property
  def _variable_shape(self):
    return self.shape
    return tensor_shape.TensorShape(self.shape)

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
@ -668,6 +821,74 @@ class _NumericColumn(_DenseColumn,
    return inputs.get(self)


class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
                        collections.namedtuple('_BucketizedColumn', [
                            'source_column', 'boundaries'])):
  """See `bucketized_column`."""

  @property
  def name(self):
    return '{}_bucketized'.format(self.source_column.name)

  @property
  def _parse_example_config(self):
    return self.source_column._parse_example_config  # pylint: disable=protected-access

  def _transform_feature(self, inputs):
    source_tensor = inputs.get(self.source_column)
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  @property
  def _variable_shape(self):
    return tensor_shape.TensorShape(
        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return array_ops.one_hot(
        indices=math_ops.to_int64(input_tensor),
        depth=len(self.boundaries) + 1,
        on_value=1.,
        off_value=0.)

  @property
  def _num_buckets(self):
    # By construction, source_column is always one-dimensional.
    return (len(self.boundaries) + 1) * self.source_column.shape[0]

  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    input_tensor = inputs.get(self)
    batch_size = array_ops.shape(input_tensor)[0]
    # By construction, source_column is always one-dimensional.
    source_dimension = self.source_column.shape[0]

    i1 = array_ops.reshape(
        array_ops.tile(
            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
            [1, source_dimension]),
        (-1,))
    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
    # Flatten the bucket indices and unique them across dimensions.
    # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets.
    bucket_indices = (
        array_ops.reshape(input_tensor, (-1,)) +
        (len(self.boundaries) + 1) * i2)

    indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2))))
    dense_shape = math_ops.to_int64(array_ops.stack(
        [batch_size, source_dimension]))
    sparse_tensor = sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=bucket_indices,
        dense_shape=dense_shape)
    return _CategoricalColumn.IdWeightPair(sparse_tensor, None)


def _create_tuple(shape, value):
  """Returns a tuple with given shape and filled with value."""
  if shape:

@ -65,7 +65,7 @@ class LazyColumnTest(test.TestCase):
    def _parse_example_config(self):
      pass

    builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
    builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
    column = TransformCounter()
    self.assertEqual(0, column.num_transform)
    builder.get(column)
@ -88,7 +88,7 @@ class LazyColumnTest(test.TestCase):
    def _parse_example_config(self):
      pass

    builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
    builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
    column = Transformer()
    self.assertEqual('Output', builder.get(column))
    self.assertEqual('Output', builder.get(column))
@ -108,13 +108,13 @@ class LazyColumnTest(test.TestCase):
    def _parse_example_config(self):
      pass

    features = {'a': constant_op.constant([[2], [3.]])}
    features = {'a': [[2], [3.]]}
    builder = fc._LazyBuilder(features=features)
    builder.get(Transformer())
    self.assertEqual(['a'], list(features.keys()))

  def test_error_if_feature_is_not_found(self):
    builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
    builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
    with self.assertRaisesRegexp(ValueError,
                                 'bbb is not in features dictionary'):
      builder.get('bbb')
@ -135,7 +135,7 @@ class LazyColumnTest(test.TestCase):
    def _parse_example_config(self):
      pass

    builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
    builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
    with self.assertRaisesRegexp(ValueError,
                                 'NotAProperColumn is not supported'):
      builder.get(NotAProperColumn())
@ -145,13 +145,13 @@ class LazyColumnTest(test.TestCase):
    class NotAFeatureColumn(object):
      pass

    builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
    builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
    with self.assertRaisesRegexp(
        TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
      builder.get(NotAFeatureColumn())


class NumericalColumnTest(test.TestCase):
class NumericColumnTest(test.TestCase):

  def test_defaults(self):
    a = fc.numeric_column('aaa')
@ -273,7 +273,7 @@ class NumericalColumnTest(test.TestCase):

    price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
    builder = fc._LazyBuilder({
        'price': constant_op.constant([[1., 2.], [5., 6.]])
        'price': [[1., 2.], [5., 6.]]
    })
    output = builder.get(price)
    with self.test_session():
@ -286,7 +286,7 @@ class NumericalColumnTest(test.TestCase):

    price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
    builder = fc._LazyBuilder({
        'price': constant_op.constant([[1., 2.], [5., 6.]])
        'price': [[1., 2.], [5., 6.]]
    })
    self.assertEqual(builder.get(price), price._get_dense_tensor(builder))

@ -315,7 +315,7 @@ class NumericalColumnTest(test.TestCase):
  def test_make_linear_model(self):
    price = fc.numeric_column('price')
    with ops.Graph().as_default():
      features = {'price': constant_op.constant([[1.], [5.]])}
      features = {'price': [[1.], [5.]]}
      predictions = fc.make_linear_model(features, [price])
      bias = get_linear_model_bias()
      price_var = get_linear_model_column_var(price)
|
||||
@ -327,6 +327,231 @@ class NumericalColumnTest(test.TestCase):
|
||||
self.assertAllClose([[10.], [50.]], predictions.eval())
|
||||
|
||||
|
||||
class BucketizedColumnTest(test.TestCase):
|
||||
|
||||
def test_invalid_source_column_type(self):
|
||||
a = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'source_column must be a column generated with numeric_column'):
|
||||
fc.bucketized_column(a, boundaries=[0, 1])
|
||||
|
||||
def test_invalid_source_column_shape(self):
|
||||
a = fc.numeric_column('aaa', shape=[2, 3])
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'source_column must be one-dimensional column'):
|
||||
fc.bucketized_column(a, boundaries=[0, 1])
|
||||
|
||||
def test_invalid_boundaries(self):
|
||||
a = fc.numeric_column('aaa')
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'boundaries must be a sorted list'):
|
||||
fc.bucketized_column(a, boundaries=None)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'boundaries must be a sorted list'):
|
||||
fc.bucketized_column(a, boundaries=1.)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'boundaries must be a sorted list'):
|
||||
fc.bucketized_column(a, boundaries=[1, 0])
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'boundaries must be a sorted list'):
|
||||
fc.bucketized_column(a, boundaries=[1, 1])
|
||||
|
||||
def test_name(self):
|
||||
a = fc.numeric_column('aaa', dtype=dtypes.int32)
|
||||
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||
self.assertEqual('aaa_bucketized', b.name)
|
||||
|
||||
def test_parse_config(self):
|
||||
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
|
||||
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||
self.assertEqual({
|
||||
'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
|
||||
}, b._parse_example_config)
|
||||
|
||||
def test_variable_shape(self):
|
||||
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
|
||||
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||
# Column 'aaa' has shape [2] times three buckets -> variable_shape=[2, 3].
self.assertAllEqual((2, 3), b._variable_shape)

def test_num_buckets(self):
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
# Column 'aaa' has shape [2] times three buckets -> num_buckets=6.
self.assertEqual(6, b._num_buckets)

def test_parse_example(self):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
data = example_pb2.Example(features=feature_pb2.Features(
feature={
'price':
feature_pb2.Feature(float_list=feature_pb2.FloatList(
value=[20., 110.]))
}))
features = parsing_ops.parse_example(
serialized=[data.SerializeToString()],
features=bucketized_price._parse_example_config)
self.assertIn('price', features)
with self.test_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())

def test_transform_feature(self):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1., 1.], [5., 6.]]
})
transformed_tensor = builder.get(bucketized_price)
with _initialized_session():
self.assertAllEqual([[0, 1], [3, 4]], transformed_tensor.eval())
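The expected indices above follow the same convention as numpy.digitize; a quick sketch (NumPy only, not part of the diff) reproduces them:

import numpy as np
print(np.digitize([-1., 1., 5., 6.], bins=[0, 2, 4, 6]))  # [0 1 3 4]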

def test_get_dense_tensor_one_input_value(self):
"""Tests _get_dense_tensor() for input with shape=[1]."""
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1.], [1.], [5.], [6.]]
})
with _initialized_session():
bucketized_price_tensor = bucketized_price._get_dense_tensor(builder)
self.assertAllClose(
# One-hot tensor.
[[[1., 0., 0., 0., 0.]],
[[0., 1., 0., 0., 0.]],
[[0., 0., 0., 1., 0.]],
[[0., 0., 0., 0., 1.]]],
bucketized_price_tensor.eval())

def test_get_dense_tensor_two_input_values(self):
"""Tests _get_dense_tensor() for input with shape=[2]."""
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1., 1.], [5., 6.]]
})
with _initialized_session():
bucketized_price_tensor = bucketized_price._get_dense_tensor(builder)
self.assertAllClose(
# One-hot tensor.
[[[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]],
[[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]],
bucketized_price_tensor.eval())

def test_get_sparse_tensors_one_input_value(self):
"""Tests _get_sparse_tensors() for input with shape=[1]."""
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1.], [1.], [5.], [6.]]
})
with _initialized_session() as sess:
id_weight_pair = bucketized_price._get_sparse_tensors(builder)
self.assertIsNone(id_weight_pair.weight_tensor)
id_tensor_value = sess.run(id_weight_pair.id_tensor)
self.assertAllEqual(
[[0, 0], [1, 0], [2, 0], [3, 0]], id_tensor_value.indices)
self.assertAllEqual([0, 1, 3, 4], id_tensor_value.values)
self.assertAllEqual([4, 1], id_tensor_value.dense_shape)

def test_get_sparse_tensors_two_input_values(self):
"""Tests _get_sparse_tensors() for input with shape=[2]."""
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
builder = fc._LazyBuilder({
'price': [[-1., 1.], [5., 6.]]
})
with _initialized_session() as sess:
id_weight_pair = bucketized_price._get_sparse_tensors(builder)
self.assertIsNone(id_weight_pair.weight_tensor)
id_tensor_value = sess.run(id_weight_pair.id_tensor)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1]], id_tensor_value.indices)
# Values 0-4 correspond to the first column of the input price.
# Values 5-9 correspond to the second column of the input price.
self.assertAllEqual([0, 6, 3, 9], id_tensor_value.values)
self.assertAllEqual([2, 2], id_tensor_value.dense_shape)

def test_sparse_tensor_input_not_supported(self):
price = fc.numeric_column('price')
bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
builder = fc._LazyBuilder({
'price':
sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
})
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
bucketized_price._transform_feature(builder)

def test_deep_copy(self):
a = fc.numeric_column('aaa', shape=[2])
a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
a_bucketized_copy = copy.deepcopy(a_bucketized)
self.assertEqual(a_bucketized_copy.name, 'aaa_bucketized')
self.assertAllEqual(a_bucketized_copy._variable_shape, (2, 3))
self.assertEqual(a_bucketized_copy.boundaries, (0, 1))

def test_make_linear_model_one_input_value(self):
"""Tests make_linear_model() for input with shape=[1]."""
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
predictions = fc.make_linear_model(features, [bucketized_price])
bias = get_linear_model_bias()
bucketized_price_var = get_linear_model_column_var(bucketized_price)
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
self.assertAllClose(
[[0.], [0.], [0.], [0.], [0.]], bucketized_price_var.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
sess.run(bucketized_price_var.assign(
[[10.], [20.], [30.], [40.], [50.]]))
# price -1. is in the 0th bucket, whose weight is 10.
# price 1. is in the 1st bucket, whose weight is 20.
# price 5. is in the 3rd bucket, whose weight is 40.
# price 6. is in the 4th bucket, whose weight is 50.
self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
sess.run(bias.assign([1.]))
self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())

def test_make_linear_model_two_input_values(self):
"""Tests make_linear_model() for input with shape=[2]."""
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
predictions = fc.make_linear_model(features, [bucketized_price])
bias = get_linear_model_bias()
bucketized_price_var = get_linear_model_column_var(bucketized_price)
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
self.assertAllClose(
[[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
bucketized_price_var.eval())
self.assertAllClose([[0.], [0.]], predictions.eval())
sess.run(bucketized_price_var.assign(
[[10.], [20.], [30.], [40.], [50.],
[60.], [70.], [80.], [90.], [100.]]))
# 1st example:
# price -1. is in the 0th bucket, whose weight is 10.
# price 1. is in the 6th bucket, whose weight is 70.
# 2nd example:
# price 5. is in the 3rd bucket, whose weight is 40.
# price 6. is in the 9th bucket, whose weight is 100.
self.assertAllClose([[80.], [140.]], predictions.eval())
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())


class SparseColumnHashedTest(test.TestCase):

def test_defaults(self):
@ -396,15 +621,15 @@ class SparseColumnHashedTest(test.TestCase):
float_fc = fc.categorical_column_with_hash_bucket(
'a_float', 10, dtype=dtypes.string)
int_tensor = sparse_tensor.SparseTensor(
values=constant_op.constant([101]),
values=[101],
indices=[[0, 0]],
dense_shape=[1, 1])
string_tensor = sparse_tensor.SparseTensor(
values=constant_op.constant(['101']),
values=['101'],
indices=[[0, 0]],
dense_shape=[1, 1])
float_tensor = sparse_tensor.SparseTensor(
values=constant_op.constant([101.]),
values=[101.],
indices=[[0, 0]],
dense_shape=[1, 1])
builder = fc._LazyBuilder({
@ -520,7 +745,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_bias(self):
price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])}
features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price])
bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price)
@ -567,10 +792,63 @@ class MakeLinearModelTest(test.TestCase):
sess.run(price_var.assign([[10.]]))
self.assertAllClose([[1015.], [10065.]], predictions.eval())

def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""

class _DenseAndSparseColumn(fc._DenseColumn, fc._CategoricalColumn):

@property
def name(self):
return 'dense_and_sparse_column'

@property
def _parse_example_config(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}

def _transform_feature(self, inputs):
return inputs.get(self.name)

@property
def _variable_shape(self):
raise ValueError('Should not use this method.')

def _get_dense_tensor(self, inputs, weight_collections=None,
trainable=None):
raise ValueError('Should not use this method.')

@property
def _num_buckets(self):
return 4

def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
return fc._CategoricalColumn.IdWeightPair(sp_tensor, None)

dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
sp_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
predictions = fc.make_linear_model(features, [dense_and_sparse_column])
bias = get_linear_model_bias()
dense_and_sparse_column_var = get_linear_model_column_var(
dense_and_sparse_column)
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [10015.]], predictions.eval())

def test_dense_multi_output(self):
price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])}
features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price], units=3)
bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price)
@ -607,7 +885,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_multi_dimension(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': constant_op.constant([[1., 2.], [5., 6.]])}
features = {'price': [[1., 2.], [5., 6.]]}
predictions = fc.make_linear_model(features, [price])
price_var = get_linear_model_column_var(price)
with _initialized_session() as sess:
@ -635,7 +913,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_multi_dimension_multi_output(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': constant_op.constant([[1., 2.], [5., 6.]])}
features = {'price': [[1., 2.], [5., 6.]]}
predictions = fc.make_linear_model(features, [price], units=3)
bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price)
@ -650,7 +928,7 @@ class MakeLinearModelTest(test.TestCase):
def test_raises_if_shape_mismatch(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': constant_op.constant([[1.], [5.]])}
features = {'price': [[1.], [5.]]}
predictions = fc.make_linear_model(features, [price])
with _initialized_session():
with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
@ -659,7 +937,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': constant_op.constant([[[1., 2.]], [[5., 6.]]])}
features = {'price': [[[1., 2.]], [[5., 6.]]]}
predictions = fc.make_linear_model(features, [price])
bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price)
@ -675,8 +953,8 @@ class MakeLinearModelTest(test.TestCase):
price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': constant_op.constant([[1., 2.], [5., 6.]]),
'price2': constant_op.constant([[3.], [4.]])
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
predictions = fc.make_linear_model(features, [price1, price2])
bias = get_linear_model_bias()
@ -695,7 +973,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_collection(self):
price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': constant_op.constant([[1.], [5.]])}
features = {'price': [[1.], [5.]]}
fc.make_linear_model(features, [price], weight_collections=['my-vars'])
my_vars = g.get_collection('my-vars')
bias = get_linear_model_bias()
@ -720,7 +998,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_trainable_default(self):
price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': constant_op.constant([[1.], [5.]])}
features = {'price': [[1.], [5.]]}
fc.make_linear_model(features, [price])
bias = get_linear_model_bias()
price_var = get_linear_model_column_var(price)
@ -744,7 +1022,7 @@ class MakeLinearModelTest(test.TestCase):
def test_dense_trainable_false(self):
price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': constant_op.constant([[1.], [5.]])}
features = {'price': [[1.], [5.]]}
fc.make_linear_model(features, [price], trainable=False)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
@ -796,5 +1074,89 @@ class MakeLinearModelTest(test.TestCase):
self.assertIn('wire_cast', my_vars[2].name)


class MakeInputLayerTest(test.TestCase):

def test_should_be_dense_column(self):
with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
fc.make_input_layer(
features={'a': [[0]]},
feature_columns=[
fc.categorical_column_with_hash_bucket('wire_cast', 4)
])

def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
fc.make_input_layer(
features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})

def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
fc.make_input_layer(
features={'a': [[0]]},
feature_columns=[fc.numeric_column('a'),
fc.numeric_column('a')])

def test_one_column(self):
price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
self.assertAllClose([[1.], [5.]], net.eval())

def test_multi_dimension(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())

def test_raises_if_shape_mismatch(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
net.eval()

def test_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
net = fc.make_input_layer(features, [price])
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())

def test_multi_column(self):
price1 = fc.numeric_column('price1', shape=2)
price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
net = fc.make_input_layer(features, [price1, price2])
with _initialized_session():
self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())

def test_column_order(self):
price_a = fc.numeric_column('price_a')
price_b = fc.numeric_column('price_b')
with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
}
net1 = fc.make_input_layer(features, [price_a, price_b])
net2 = fc.make_input_layer(features, [price_b, price_a])
with _initialized_session():
self.assertAllClose([[1., 3.]], net1.eval())
self.assertAllClose([[1., 3.]], net2.eval())


if __name__ == '__main__':
test.main()

@ -422,14 +422,15 @@ def import_scoped_meta_graph(meta_graph_or_file,
graph=None,
import_scope=None,
input_map=None,
unbound_inputs_col_name="unbound_inputs"):
unbound_inputs_col_name="unbound_inputs",
restore_collections_predicate=(lambda key: True)):
"""Recreates a `Graph` saved in a `MetaGraphDef` proto.

This function takes a `MetaGraphDef` protocol buffer as input. If
the argument is a file containing a `MetaGraphDef` protocol buffer,
it constructs a protocol buffer from the file content. The function
then adds all the nodes from the `graph_def` field to the
current graph, recreates all the collections, and returns a saver
current graph, recreates the desired collections, and returns a saver
constructed from the `saver_def` field.

In combination with `export_scoped_meta_graph()`, this function can be used to
@ -453,6 +454,10 @@ def import_scoped_meta_graph(meta_graph_or_file,
`Tensor` objects. The values of the named input tensors in the imported
graph will be re-mapped to the respective `Tensor` values.
unbound_inputs_col_name: Collection name for looking up unbound inputs.
restore_collections_predicate: a predicate on collection names. A collection
named c (i.e. whose key is c) will be restored iff
1) `restore_collections_predicate(c)` is True, and
2) `c != unbound_inputs_col_name`.

Returns:
A dictionary of all the `Variables` imported into the name scope.
@ -503,6 +508,8 @@ def import_scoped_meta_graph(meta_graph_or_file,
# Don't add unbound_inputs to the new graph.
if key == unbound_inputs_col_name:
continue
if not restore_collections_predicate(key):
continue

kind = col_def.WhichOneof("kind")
if kind is None:

@ -335,6 +335,66 @@ class ScopedMetaGraphTest(test.TestCase):
for a, b in zip(orig_meta_graphs, new_meta_graphs):
test_util.assert_meta_graph_protos_equal(self, a, b)

def testScopedImportWithSelectedCollections(self):
meta_graph_filename = os.path.join(
_TestDir("selected_collections_import"), "meta_graph.pb")

graph = ops.Graph()
# Add a variable to populate two collections. The functionality tested is
# not specific to variables, but using variables in the test is convenient.
with graph.as_default():
variables.Variable(initial_value=1.0, trainable=True)
self.assertTrue(
all([
graph.get_collection(key)
for key in
[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES]
]))
meta_graph.export_scoped_meta_graph(
filename=meta_graph_filename, graph=graph)

def _test_import(include_collection_keys, omit_collection_keys):
assert set(include_collection_keys).isdisjoint(omit_collection_keys)
newgraph = ops.Graph()
import_scope = "some_scope_name"

def _restore_collections_predicate(collection_key):
return (collection_key in include_collection_keys and
collection_key not in omit_collection_keys)

meta_graph.import_scoped_meta_graph(
meta_graph_filename,
graph=newgraph,
import_scope=import_scope,
restore_collections_predicate=_restore_collections_predicate)
collection_values = [
newgraph.get_collection(name=key, scope=import_scope)
for key in include_collection_keys
]
self.assertTrue(all(collection_values))
collection_values = [
newgraph.get_collection(name=key, scope=import_scope)
for key in omit_collection_keys
]
self.assertFalse(any(collection_values))

_test_import(
include_collection_keys=[
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
],
omit_collection_keys=[])
_test_import(
include_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES],
omit_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES])
_test_import(
include_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES],
omit_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES])
_test_import(
include_collection_keys=[],
omit_collection_keys=[
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
])
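Outside of this test, the new argument composes the same way. A sketch (hypothetical file name, not part of the diff) that restores only trainable variables while skipping every other collection:

meta_graph.import_scoped_meta_graph(
    "meta_graph.pb",
    import_scope="imported",
    restore_collections_predicate=(
        lambda key: key == ops.GraphKeys.TRAINABLE_VARIABLES))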

def _testScopedExportWithQueue(self, test_dir, exported_filename):
graph = ops.Graph()
with graph.as_default():

@ -113,10 +113,9 @@ class DepthwiseConv2DTest(test.TestCase):
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
# Initializes the input and filter tensor with numbers incrementing from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [1.0 for f in range(1, total_size_2 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
with self.test_session(use_gpu=use_gpu) as sess:
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t1.set_shape(tensor_in_sizes)
@ -147,8 +146,9 @@ class DepthwiseConv2DTest(test.TestCase):
native_result = sess.run(conv_native)
interface_result = sess.run(conv_interface)

print("diff matrix:",
np.amax(np.ravel(native_result) - np.ravel(interface_result)))
print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes,
", stride:", stride, ", padding: ", padding, ", max diff: ",
np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), 1e-5)
self.assertShapeEqual(native_result, conv_native)

@ -88,6 +88,7 @@ class SparseAddTest(test.TestCase):
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
self.assertAllEqual((3, 3), sp_sum.get_shape())

sum_out = sess.run(sp_sum)

@ -328,6 +328,12 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
self._SHP_2_5_6)

def testStaticShapeInfoPreservedWhenNewShapeIsProvidedAndStatic(self):
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 6, 7], dtype=np.int64)
sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
self.assertAllEqual([3, 6, 7], sp_output.get_shape())

def testBasic(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6()
@ -397,14 +403,21 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError("x == y did not hold element-wise"):
sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})

def testInvalidDimensionSize(self):
with self.test_session(use_gpu=False) as sess:
def testInvalidDimensionSizeStatic(self):
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 7, 5], dtype=np.int64)

with self.assertRaisesRegexp(ValueError, "should have dimension sizes"):
sparse_ops.sparse_reset_shape(sp_input, new_shape)

def testInvalidDimensionSizeDynamic(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6()
new_shape = array_ops.placeholder(dtype=dtypes.int32)
out = sparse_ops.sparse_reset_shape(sp_input, new_shape)

with self.assertRaisesOpError("x <= y did not hold element-wise"):
sess.run(out)
sess.run(out, feed_dict={new_shape: [3, 7, 5]})

def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)

@ -48,6 +48,13 @@ class SparseReorderTest(test.TestCase):
shape = np.array([5, 6]).astype(np.int64)
return sparse_tensor.SparseTensorValue(ind, val, shape)

def testStaticShapeInfoPreserved(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_5x6(np.arange(6)))
self.assertAllEqual((5, 6), sp_input.get_shape())
sp_output = sparse_ops.sparse_reorder(sp_input)
self.assertAllEqual((5, 6), sp_output.get_shape())

def testAlreadyInOrder(self):
with self.test_session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6(np.arange(6))

@ -50,6 +50,13 @@ class SparseReshapeTest(test.TestCase):
shape = np.array([2, 3, 4])
return sparse_tensor.SparseTensorValue(ind, val, shape)

def testStaticShapeInfoPreserved(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_5x6())
self.assertAllEqual((5, 6), sp_input.get_shape())
sp_output = sparse_ops.sparse_reshape(sp_input, shape=(1, 5, 2, 3))
self.assertAllEqual((1, 5, 2, 3), sp_output.get_shape())

def testSameShape(self):
with self.test_session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6()
@ -180,6 +187,12 @@ class SparseReshapeTest(test.TestCase):
with self.assertRaisesOpError("only one output shape size may be -1"):
sess.run(sp_output, {sp_input: input_val})

def testProvideStaticallyMismatchedSizes(self):
input_val = self._SparseTensorValue_5x6()
sp_input = sparse_tensor.SparseTensor.from_value(input_val)
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
sparse_ops.sparse_reshape(sp_input, [4, 7])

def testFeedMismatchedSizes(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()

@ -774,6 +774,11 @@ class VariableScopeTest(test.TestCase):
self.assertEqual([v.name
for v in scope.global_variables()], ["foo/b:0"])

def testGetVariableWithRefDtype(self):
v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
# Ensure it is possible to do get_variable with a _ref dtype passed in.
_ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype)


def axis0_into1_partitioner(shape=None, **unused_kwargs):
part = [1] * len(shape)

@ -335,7 +335,7 @@ class Layer(object):

def add_variable(self, name, shape, dtype=None,
initializer=None, regularizer=None, trainable=True):
"""Adds a new variable to the layer.
"""Adds a new variable to the layer, or gets an existing one; returns it.

Arguments:
name: variable name.
@ -424,7 +424,6 @@ class Layer(object):
self.build(input_shapes[0])
else:
self.build(input_shapes)
self.built = True
if 'scope' in tf_inspect.getargspec(self.call).args:
kwargs['scope'] = scope
outputs = self.call(inputs, *args, **kwargs)
@ -443,6 +442,7 @@ class Layer(object):

# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
self.built = True
return outputs

@property

@ -153,6 +153,36 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/Square')

def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self):

class MyLayer(base_layers.Layer):

def build(self, _):
# Do not mark the layer as built.
pass

def call(self, inputs):
self.my_var = self.add_variable('my_var', [2, 2])
if self.built:
# Skip creating on the first call; try to create after it's
# built. This is expected to fail.
self.add_variable('this_will_break_on_second_call', [2, 2])
return inputs + math_ops.square(self.my_var)

layer = MyLayer(name='my_layer')
inputs = random_ops.random_uniform((2,), seed=1)
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/add')
self.assertListEqual(
[v.name for v in layer.variables], ['my_layer/my_var:0'])
with self.assertRaisesRegexp(ValueError,
'my_layer/this_will_break_on_second_call'):
layer.apply(inputs)
# The list of variables hasn't changed.
self.assertListEqual(
[v.name for v in layer.variables], ['my_layer/my_var:0'])

def testDeepCopy(self):

class MyLayer(base_layers.Layer):

@ -145,6 +145,7 @@ class _Conv(base.Layer):
dtype=self.dtype)
else:
self.bias = None
self.built = True

def call(self, inputs):
outputs = nn.convolution(
@ -837,6 +838,7 @@ class SeparableConv2D(Conv2D):
dtype=self.dtype)
else:
self.bias = None
self.built = True

def call(self, inputs):
if self.data_format == 'channels_first':
@ -1070,6 +1072,7 @@ class Conv2DTranspose(Conv2D):
dtype=self.dtype)
else:
self.bias = None
self.built = True

def call(self, inputs):
inputs_shape = array_ops.shape(inputs)
@ -1297,6 +1300,7 @@ class Conv3DTranspose(Conv3D):
dtype=self.dtype)
else:
self.bias = None
self.built = True

def call(self, inputs):
inputs_shape = array_ops.shape(inputs)

@ -130,6 +130,7 @@ class Dense(base.Layer):
trainable=True)
else:
self.bias = None
self.built = True

def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)

@ -201,6 +201,7 @@ class BatchNormalization(base.Layer):
'renorm_stddev_weight', ())
finally:
self._scope.set_partitioner(partitioner)
self.built = True

def _renorm_correction_and_moments(self, mean, variance, training):
"""Returns the correction and update values for renorm."""
@ -399,7 +400,9 @@ def batch_normalization(inputs,
training: Either a Python boolean, or a TensorFlow boolean scalar tensor
(e.g. a placeholder). Whether to return the output in training mode
(normalized with statistics of the current batch) or in inference mode
(normalized with moving statistics).
(normalized with moving statistics). **NOTE**: make sure to set this
parameter correctly, or else your training/inference will not work
properly.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: String, the name of the layer.

@ -71,6 +71,7 @@ class _Pooling1D(base.Layer):
if len(input_shape) != 3:
raise ValueError('Inputs should have rank 3. '
'Received input shape:', str(input_shape))
self.built = True

def call(self, inputs):
# There is no TF op for 1D pooling, hence we make the inputs 4D.
@ -261,6 +262,7 @@ class _Pooling2D(base.Layer):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank 4. '
'Received input shape:', str(input_shape))
self.built = True

def call(self, inputs):
if self.data_format == 'channels_last':
@ -448,6 +450,7 @@ class _Pooling3D(base.Layer):
if len(input_shape) != 5:
raise ValueError('Inputs should have rank 5. '
'Received input shape:', str(input_shape))
self.built = True

def call(self, inputs):
pool_shape = (1,) + self.pool_size + (1,)

@ -21,7 +21,6 @@ from __future__ import print_function

import collections
import hashlib
import re
import threading

import six
@ -56,6 +55,7 @@ def _as_type_list(dtypes):
def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False,
unknown_rank_allowed=False):
"""Convert shapes to a list of tuples of int (or None)."""
del dtypes
if unknown_dim_allowed:
if (not isinstance(shapes, collections.Sequence)
or not shapes
@ -925,16 +925,18 @@ class Barrier(object):
If barrier has no completed elements, this operation will block
until there are 'num_elements' elements to take.

TODO(b/25743580): the semantics of `allow_small_batch` are experimental
and may be extended to other cases in the future.

TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
already when the barrier is closed, it will block for ever. Fix this
by using asynchronous operations.

Args:
num_elements: The number of elements to take.
allow_small_batch: If the barrier is closed, don't block if there are less
completed elements than requested, but instead return all available
completed elements.
TODO(b/25743580): the semantics of `allow_small_batch` are experimental
and may be extended to other cases in the future.
TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
already when the barrier is closed, it will block for ever. Fix this
by using asynchronous operations.
timeout: This specifies the number of milliseconds to block
before returning with DEADLINE_EXCEEDED. (This option is not
supported yet.)

@ -51,6 +51,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -288,12 +289,21 @@ def sparse_add(a, b, thresh=0):

if all(isinstance(inp, sparse_classes) for inp in [a, b]):
a = _convert_to_sparse_tensor(a)
b = _convert_to_sparse_tensor(b)
thresh = ops.convert_to_tensor(
thresh, dtype=a.values.dtype.real_dtype, name="thresh")
output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
a.indices, a.values, a.dense_shape,
b.indices, b.values, b.dense_shape,
thresh))

# Attempt to get output_shape statically.
a.get_shape().assert_is_compatible_with(b.get_shape())
static_shape = array_ops.broadcast_static_shape(
a.get_shape(), b.get_shape())
if static_shape.is_fully_defined():
output_shape = static_shape.as_list()

return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
else:
# swap to make `a` the SparseTensor.
@ -368,8 +378,12 @@ def sparse_reorder(sp_input, name=None):
reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name))

return sparse_tensor.SparseTensor(reordered_ind, reordered_val,
array_ops.identity(sp_input.dense_shape))
if sp_input.get_shape().is_fully_defined():
dense_shape = sp_input.get_shape().as_list()
else:
dense_shape = array_ops.identity(sp_input.dense_shape)

return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
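A short sketch of the effect (assumed usage, mirroring testStaticShapeInfoPreserved): when the input shape is fully defined, the reordered tensor now reports it statically:

sp = sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1.], dense_shape=[5, 6])
print(sparse_ops.sparse_reorder(sp).get_shape())  # (5, 6)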


def sparse_reshape(sp_input, shape, name=None):
@ -416,13 +430,30 @@ def sparse_reshape(sp_input, shape, name=None):

Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
ValueError: If argument `shape` requests a `SparseTensor` with a different
number of elements than `sp_input`.
"""
sp_input = _convert_to_sparse_tensor(sp_input)
shape = ops.convert_to_tensor(shape, dtype=dtypes.int64)

with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
sp_input.indices, sp_input.dense_shape, shape, name=name)

reshaped_shape_const = tensor_util.constant_value(shape)
if (reshaped_shape_const is not None
and sp_input.get_shape().is_fully_defined()):
# Don't deal with inferred dimensions. That would add significant code.
if all(n >= 0 for n in reshaped_shape_const):
reshaped_size = np.prod(reshaped_shape_const)
in_shape_size = np.prod(sp_input.get_shape().as_list())
if reshaped_size != in_shape_size:
raise ValueError(
"Cannot reshape a tensor with %d elements to shape %s "
"(%d elements)."
% (in_shape_size, reshaped_shape_const, reshaped_size))
reshaped_shape = reshaped_shape_const

return sparse_tensor.SparseTensor(
reshaped_ind, array_ops.identity(sp_input.values),
reshaped_shape)
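As a concrete instance of the new check (numbers from testProvideStaticallyMismatchedSizes): reshaping a fully-defined 5x6 input to [4, 7] now fails at graph-construction time, since 5 * 6 = 30 elements cannot fill 4 * 7 = 28, raising "Cannot reshape a tensor with 30 elements to shape [4 7] (28 elements)."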
@ -986,6 +1017,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
TypeError: If `sp_input` is not a `SparseTensor`.
ValueError: If `new_shape` represents a tensor with a different rank from
that of `sp_input` (if shapes are known when graph is constructed).
ValueError: If `new_shape` is determined during graph build to have
dimension sizes that are too small.
OpError:
- If `new_shape` has dimension sizes that are too small.
- If shapes are not known during graph construction time, and during run
@ -1009,6 +1042,19 @@ def sparse_reset_shape(sp_input, new_shape=None):
# error before the sparse_tensor.SparseTensor catches it.
output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])

output_shape_tensor_const = tensor_util.constant_value(
output_shape_tensor)
# For cases where all shapes are known during graph construction
if (output_shape_tensor_const is not None
and sp_input.get_shape().is_fully_defined()):
in_shape_const = np.array(sp_input.get_shape().as_list())
if not np.all(in_shape_const <= output_shape_tensor_const):
raise ValueError(
"Requested new_shape should have dimension sizes >= sp_input.shape."
" Found new_shape (%s), sp_input.shape (%s)."
% (in_shape_const, output_shape_tensor_const))
output_shape_tensor = output_shape_tensor_const
else:
# For cases where shape is not known during graph construction.
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_equal(

@ -280,6 +280,17 @@ class _VariableStore(object):
raise ValueError(
"Passed a custom_getter which is not callable: %s" % custom_getter)

# If a *_ref type is passed in an error would be triggered further down the
# stack. We prevent this using base_dtype to get a non-ref version of the
# type, before doing anything else. When _ref types are removed in favour of
# resources, this line can be removed.
try:
dtype = dtype.base_dtype
except AttributeError:
# .base_dtype not existing means that we will try and use the raw dtype
# which was passed in - this might be a NumPy type which is valid.
pass

# This is the main logic of get_variable. However, custom_getter
# may override this logic. So we save it as a callable and pass
# it to custom_getter.
@ -1281,7 +1292,7 @@ def _pure_variable_scope(name_or_scope,
well-defined semantics. Defaults to False (will later change to True).

Yields:
A scope that can be to captured and reused.
A scope that can be captured and reused.

Raises:
ValueError: when trying to reuse within a create scope, or create within

@ -56,20 +56,22 @@ Example output:
To show all available information in the SavedModel:
$saved_model_cli show --dir /tmp/saved_model --all

'run' command usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET
--signature_def SIGNATURE_DEF_KEY --inputs INPUTS
[--outdir OUTDIR] [--overwrite]
usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def
SIGNATURE_DEF_KEY [--inputs INPUTS]
[--input_exprs INPUT_EXPRS] [--outdir OUTDIR]
[--overwrite] [--tf_debug]

Examples:
To run input tensors from files through a MetaGraphDef and save the output
tensors to files:
$saved_model_cli run --dir /tmp/saved_model --tag_set serve
--signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
--outdir /tmp/out
--signature_def serving_default --inputs x=/tmp/124.npz
--input_exprs 'x2=np.ones((6,2))' --outdir /tmp/out

To observe the intermediate Tensor values in the runtime graph, use the
--tf_debug flag, e.g.:
$saved_model_cli run --dir /tmp/saved_model --tag_set serve
--signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
--signature_def serving_default --inputs 'x=/tmp/124.npz;x2=/tmp/123.npy'
--outdir /tmp/out --tf_debug

To build this tool from source, run:
@ -367,7 +369,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
output_full_path))


def preprocess_input_arg_string(inputs_str):
def preprocess_inputs_arg_string(inputs_str):
"""Parses input arg into dictionary that maps input to file/variable tuple.

Parses input string in the format of, for example,
@ -375,74 +377,94 @@ def preprocess_input_arg_string(inputs_str):
dictionary looks like
{'input_key1': (filename1, variable_name1),
'input_key2': (file2, None)}
, which maps input keys to a tuple of file name and varaible name(None if
, which maps input keys to a tuple of file name and variable name (None if
empty).

Args:
inputs_str: A string that specified where to load inputs. Each input is
separated by comma.
* If the command line arg for inputs is quoted and contains
whitespace(s), all whitespaces will be ignored.
inputs_str: A string that specifies where to load inputs. Inputs are
separated by semicolons.
* For each input key:
'input=filename<[variable_name]>'
* The "[variable_name]" key is optional. Will be set to None if not
specified.
'<input_key>=<filename>' or
'<input_key>=<filename>[<variable_name>]'
* The optional 'variable_name' key will be set to None if not specified.

Returns:
A dictionary that maps input keys to a tuple of file name and varaible name.
A dictionary that maps input keys to a tuple of file name and variable name.

Raises:
RuntimeError: An error when the given input is in a bad format.
RuntimeError: An error when the given input string is in a bad format.
"""
input_dict = {}
inputs_raw = inputs_str.split(',')
inputs_raw = inputs_str.split(';')
for input_raw in filter(bool, inputs_raw):  # skip empty strings
# Remove quotes and whitespaces
input_raw = input_raw.replace('"', '').replace('\'', '').replace(' ', '')

# Format of input=filename[variable_name]'
match = re.match(r'^([\w\-]+)=([\w\-.\/]+)\[([\w\-]+)\]$', input_raw)
match = re.match(r'([^=]+)=([^\[\]]+)\[([^\[\]]+)\]$', input_raw)

if match:
input_dict[match.group(1)] = (match.group(2), match.group(3))
input_dict[match.group(1)] = match.group(2), match.group(3)
else:
# Format of input=filename'
match = re.match(r'^([\w\-]+)=([\w\-.\/]+)$', input_raw)
match = re.match(r'([^=]+)=([^\[\]]+)$', input_raw)
if match:
input_dict[match.group(1)] = (match.group(2), None)
input_dict[match.group(1)] = match.group(2), None
else:
raise RuntimeError(
'Input \"%s\" format is incorrect. Please follow \"--inputs '
'input_key=file_name[variable_name]\" or input_key=file_name' %
input_raw)
'--inputs "%s" format is incorrect. Please follow'
|
||||
'"<input_key>=<filename>", or'
|
||||
'"<input_key>=<filename>[<variable_name>]"' % input_raw)
return input_dict
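A sketch of the parse (hypothetical paths, format as documented above):

print(preprocess_inputs_arg_string('x=/tmp/124.npz;y=/tmp/123.npy[arr]'))
# {'x': ('/tmp/124.npz', None), 'y': ('/tmp/123.npy', 'arr')}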


def load_inputs_from_input_arg_string(inputs_str):
"""Parses input arg string and load inputs into a dictionary.
def preprocess_input_exprs_arg_string(input_exprs_str):
"""Parses input arg into dictionary that maps input key to python expression.

Parses input string in the format of, for example,
"input1=filename1[variable_name1],input2=filename2" into a
dictionary looks like
{'input1:0': ndarray_saved_as_variable_name1_in_filename1 ,
'input2:0': ndarray_saved_in_filename2}
, which maps input keys to a numpy ndarray loaded from file. See Args section
for more details on inputs format.
Parses input string in the format of 'input_key=<python expression>' into a
dictionary that maps each input_key to its python expression.

Args:
input_exprs_str: A string that specifies python expression for input keys.
Each input is separated by semicolon. For each input key:
'input_key=<python expression>'

Returns:
A dictionary that maps input keys to python expressions.

Raises:
RuntimeError: An error when the given input string is in a bad format.
"""
input_dict = {}

for input_raw in filter(bool, input_exprs_str.split(';')):
if '=' not in input_raw:
raise RuntimeError('--input_exprs "%s" format is incorrect. Please follow '
'"<input_key>=<python expression>"' % input_raw)
input_key, expr = input_raw.split('=', 1)
input_dict[input_key] = expr

return input_dict
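And the expression variant (again with made-up input); note the values stay unevaluated strings until load time:

print(preprocess_input_exprs_arg_string('x=np.ones((6,2));y=[1,2]'))
# {'x': 'np.ones((6,2))', 'y': '[1,2]'}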


def load_inputs_from_input_arg_string(inputs_str, input_exprs_str):
"""Parses input arg strings and creates the inputs feed_dict.

Parses '--inputs' string for inputs to be loaded from file, and parses
'--input_exprs' string for inputs to be evaluated from python expression.

Args:
inputs_str: A string that specifies where to load inputs. Each input is
separated by comma.
* If the command line arg for inputs is quoted and contains
whitespace(s), all whitespaces will be ignored.
separated by semicolon.
* For each input key:
'input=filename[variable_name]'
'<input_key>=<filename>' or
'<input_key>=<filename>[<variable_name>]'
* The optional 'variable_name' key will be set to None if not specified.
* File specified by 'filename' will be loaded using numpy.load. Inputs
can be loaded from only .npy, .npz or pickle files.
* The "[variable_name]" key is optional depending on the input file type
as described in more detail below.
|
||||
When loading from a npy file, which always contains a numpy ndarray, the
|
||||
content will be directly assigned to the specified input tensor. If a
|
||||
varaible_name is specified, it will be ignored and a warning will be
|
||||
variable_name is specified, it will be ignored and a warning will be
|
||||
issued.
|
||||
When loading from a npz zip file, user can specify which variable within
|
||||
the zip file to load for the input tensor inside the square brackets. If
|
||||
@ -453,10 +475,12 @@ def load_inputs_from_input_arg_string(inputs_str):
        to the specified input tensor, else SavedModel CLI will assume a
        dictionary is stored in the pickle file and the value corresponding to
        the variable_name will be used.
    input_exprs_str: A string that specifies python expressions for inputs.
        * In the format of: '<input_key>=<python expression>'.
        * numpy module is available as np.

  Returns:
    A dictionary that maps input tensor keys to a numpy ndarray loaded from
    file.
    A dictionary that maps input tensor keys to numpy ndarrays.

  Raises:
    RuntimeError: An error when a key is specified, but the input file contains
@ -466,13 +490,14 @@ def load_inputs_from_input_arg_string(inputs_str):
  """
  tensor_key_feed_dict = {}

  for input_tensor_key, (
      filename,
      variable_name) in preprocess_input_arg_string(inputs_str).items():
    # When a variable_name key is specified for the input file
    if variable_name:
  inputs = preprocess_inputs_arg_string(inputs_str)
  input_exprs = preprocess_input_exprs_arg_string(input_exprs_str)

  for input_tensor_key, (filename, variable_name) in inputs.items():
    data = np.load(filename)

    # When a variable_name key is specified for the input file
    if variable_name:
      # if file contains a single ndarray, ignore the input name
      if isinstance(data, np.ndarray):
        warnings.warn(
@ -488,7 +513,6 @@ def load_inputs_from_input_arg_string(inputs_str):
            (filename, variable_name))
    # When no key is specified for the input file.
    else:
      data = np.load(filename)
      # Check if npz file only contains a single numpy ndarray.
      if isinstance(data, np.lib.npyio.NpzFile):
        variable_name_list = data.files
@ -500,6 +524,16 @@ def load_inputs_from_input_arg_string(inputs_str):
      else:
        tensor_key_feed_dict[input_tensor_key] = data

  # When input is a python expression:
  for input_tensor_key, py_expr in input_exprs.items():
    if input_tensor_key in tensor_key_feed_dict:
      warnings.warn(
          'input_key %s has been specified with both --inputs and --input_exprs'
          ' options. Value in --input_exprs will be used.' % input_tensor_key)

    # ast.literal_eval does not work with numpy expressions
    tensor_key_feed_dict[input_tensor_key] = eval(py_expr)  # pylint: disable=eval-used

  return tensor_key_feed_dict

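A usage sketch under the same assumptions (file path hypothetical; the eval above works because the module itself imports numpy as np):

    import numpy as np
    from tensorflow.python.tools import saved_model_cli

    np.save('/tmp/x0.npy', np.array([[1], [2]]))
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        'x0=/tmp/x0.npy', 'x1=np.ones([2,10])')
    # 'x0' comes from the .npy file, 'x1' from evaluating the expression.
    print(feed_dict['x0'].shape, feed_dict['x1'].shape)  # (2, 1) (2, 10)
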
@ -531,7 +565,8 @@ def run(args):
  Args:
    args: A namespace parsed from command line.
  """
  tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs)
  tensor_key_feed_dict = load_inputs_from_input_arg_string(
      args.inputs, args.input_exprs)
  run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
                                 tensor_key_feed_dict, args.outdir,
                                 args.overwrite, tf_debug=args.tf_debug)
@ -559,7 +594,7 @@ def create_parser():
      'MetaGraphDef specified by its tag-set:\n'
      '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n'
      'For a MetaGraphDef with multiple tags in the tag-set, all tags must be '
      'passed in, separated by \',\':\n'
      'passed in, separated by \';\':\n'
      '$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n'
      'To show all inputs and outputs TensorInfo for a specific'
      ' SignatureDef specified by the SignatureDef key in a'
@ -601,7 +636,7 @@ def create_parser():
      '$saved_model_cli show --dir /tmp/saved_model --tag_set serve'
      '--signature_def serving_default '
      '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy'
      '--outdir=/out\n\n'
      '--input_exprs \'input3_key=np.ones(2)\' --outdir=/out\n\n'
      'For more information about input file format, please see:\n'
      'https://www.tensorflow.org/programmers_guide/saved_model_cli\n')
  parser_run = subparsers.add_parser(
@ -622,10 +657,15 @@ def create_parser():
      required=True,
      metavar='SIGNATURE_DEF_KEY',
      help='key of SignatureDef to run')
  msg = ('inputs in the format of \'input_key=filename[variable_name]\', '
         'separated by \',\'. Inputs can only be loaded from .npy, .npz or '
         'pickle files. Please use input keys instead of input names.')
  parser_run.add_argument('--inputs', type=str, required=True, help=msg)
  msg = ('Loading inputs from files, in the format of \'<input_key>=<filename>,'
         ' or \'<input_key>=<filename>[<variable_name>]\', separated by \';\'.'
         ' The file format can only be from .npy, .npz or pickle.')
  parser_run.add_argument('--inputs', type=str, default='', help=msg)
  msg = ('Specifying inputs by python expressions, in the format of'
         ' "<input_key>=\'<python expression>\'", separated by \';\'. '
         'numpy module is available as \'np\'. '
         'Will override duplicate input_keys from --inputs option.')
  parser_run.add_argument('--input_exprs', type=str, default='', help=msg)
  parser_run.add_argument(
      '--outdir',
      type=str,
@ -649,6 +689,8 @@ def create_parser():
def main():
  parser = create_parser()
  args = parser.parse_args()
  if not args.inputs and not args.input_exprs:
    parser.error('At least one of --inputs and --input_exprs is required')
  args.func(args)

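Putting the new flags together, a hypothetical invocation parsed by this parser could look like (paths and keys illustrative only):

    from tensorflow.python.tools import saved_model_cli

    parser = saved_model_cli.create_parser()
    args = parser.parse_args([
        'run', '--dir', '/tmp/saved_model', '--tag_set', 'serve',
        '--signature_def', 'serving_default',
        '--inputs', 'x0=/tmp/data.npz[a]',
        '--input_exprs', 'x1=np.ones([2,10])',
        '--outdir', '/out'])
    args.func(args)  # dispatches to run(args)
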
@ -201,28 +201,37 @@ Method name is: tensorflow/serving/predict"""
    self.assertEqual(err.getvalue().strip(), '')

  def testInputPreProcessFormats(self):
    input_str = 'input1=/path/file.txt[ab3], input2=file2,,'
    input_dict = saved_model_cli.preprocess_input_arg_string(input_str)
    input_str = 'input1=/path/file.txt[ab3];input2=file2'
    input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]'
    input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
    input_expr_dict = saved_model_cli.preprocess_input_exprs_arg_string(
        input_expr_str)
    self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3'))
    self.assertTrue(input_dict['input2'] == ('file2', None))

  def testInputPreProcessQuoteAndWhitespace(self):
    input_str = '\' input1 = file[v_1]\', input2=file ["sd"] '
    input_dict = saved_model_cli.preprocess_input_arg_string(input_str)
    self.assertTrue(input_dict['input1'] == ('file', 'v_1'))
    self.assertTrue(input_dict['input2'] == ('file', 'sd'))
    self.assertTrue(input_expr_dict['input3'] == 'np.zeros([2,2])')
    self.assertTrue(input_expr_dict['input4'] == '[4,5]')
    self.assertTrue(len(input_dict) == 2)
    self.assertTrue(len(input_expr_dict) == 2)

  def testInputPreProcessFileNames(self):
    input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
                 r'input:0=c:\PROGRA~1\data.npy')
    input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
    print(input_dict)
    self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz',
                                             'v:0'))
    self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))

  def testInputPreProcessErrorBadFormat(self):
    input_str = 'inputx=file[[v1]v2'
    with self.assertRaises(RuntimeError):
      saved_model_cli.preprocess_input_arg_string(input_str)
      saved_model_cli.preprocess_inputs_arg_string(input_str)
    input_str = 'inputx:file'
    with self.assertRaises(RuntimeError):
      saved_model_cli.preprocess_input_arg_string(input_str)
    input_str = 'inputx=file(v_1)'
      saved_model_cli.preprocess_inputs_arg_string(input_str)
    input_str = 'inputx:np.zeros((5))'
    with self.assertRaises(RuntimeError):
      saved_model_cli.preprocess_input_arg_string(input_str)
      saved_model_cli.preprocess_input_exprs_arg_string(input_str)

  def testInputParserNPY(self):
    x0 = np.array([[1], [2]])
@ -231,8 +240,8 @@ Method name is: tensorflow/serving/predict"""
    input1_path = os.path.join(test.get_temp_dir(), 'input1.npy')
    np.save(input0_path, x0)
    np.save(input1_path, x1)
    input_str = 'x0=' + input0_path + '[x0],x1=' + input1_path
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
    input_str = 'x0=' + input0_path + '[x0];x1=' + input1_path
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
    self.assertTrue(np.all(feed_dict['x0'] == x0))
    self.assertTrue(np.all(feed_dict['x1'] == x1))

@ -240,8 +249,8 @@ Method name is: tensorflow/serving/predict"""
    x0 = np.array([[1], [2]])
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0)
    input_str = 'x=' + input_path + '[a],y=' + input_path
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
    input_str = 'x=' + input_path + '[a];y=' + input_path
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
    self.assertTrue(np.all(feed_dict['x'] == x0))
    self.assertTrue(np.all(feed_dict['y'] == x0))

@ -258,25 +267,50 @@ Method name is: tensorflow/serving/predict"""
      pickle.dump(pkl1, f)
    with open(input_path2, 'wb') as f:
      pickle.dump(pkl2, f)
    input_str = 'x=' + input_path0 + '[b],y=' + input_path1 + '[c],'
    input_str = 'x=' + input_path0 + '[b];y=' + input_path1 + '[c];'
    input_str += 'z=' + input_path2
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
    self.assertTrue(np.all(feed_dict['x'] == pkl0['b']))
    self.assertTrue(np.all(feed_dict['y'] == pkl1))
    self.assertTrue(np.all(feed_dict['z'] == pkl2))

  def testInputParserQuoteAndWhitespace(self):
  def testInputParserPythonExpression(self):
    x1 = np.ones([2, 10])
    x2 = np.array([[1], [2], [3]])
    x3 = np.mgrid[0:5, 0:5]
    x4 = [[3], [4]]
    input_expr_str = ('x1=np.ones([2,10]);x2=np.array([[1],[2],[3]]);'
                      'x3=np.mgrid[0:5,0:5];x4=[[3],[4]]')
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        '', input_expr_str)
    self.assertTrue(np.all(feed_dict['x1'] == x1))
    self.assertTrue(np.all(feed_dict['x2'] == x2))
    self.assertTrue(np.all(feed_dict['x3'] == x3))
    self.assertTrue(np.all(feed_dict['x4'] == x4))

  def testInputParserBoth(self):
    x0 = np.array([[1], [2]])
    x1 = np.array(range(6)).reshape(2, 3)
    input0_path = os.path.join(test.get_temp_dir(), 'input0.npy')
    input1_path = os.path.join(test.get_temp_dir(), 'input1.npy')
    np.save(input0_path, x0)
    np.save(input1_path, x1)
    input_str = '"x0=' + input0_path + '[x0] , x1 = ' + input1_path + '"'
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0)
    x1 = np.ones([2, 10])
    input_str = 'x0=' + input_path + '[a]'
    input_expr_str = 'x1=np.ones([2,10])'
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, input_expr_str)
    self.assertTrue(np.all(feed_dict['x0'] == x0))
    self.assertTrue(np.all(feed_dict['x1'] == x1))

  def testInputParserBothDuplicate(self):
    x0 = np.array([[1], [2]])
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0)
    x1 = np.ones([2, 10])
    input_str = 'x0=' + input_path + '[a]'
    input_expr_str = 'x0=np.ones([2,10])'
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, input_expr_str)
    self.assertTrue(np.all(feed_dict['x0'] == x1))

  def testInputParserErrorNoName(self):
    x0 = np.array([[1], [2]])
    x1 = np.array(range(5))
@ -284,7 +318,7 @@ Method name is: tensorflow/serving/predict"""
    np.savez(input_path, a=x0, b=x1)
    input_str = 'x=' + input_path
    with self.assertRaises(RuntimeError):
      saved_model_cli.load_inputs_from_input_arg_string(input_str)
      saved_model_cli.load_inputs_from_input_arg_string(input_str, '')

  def testInputParserErrorWrongName(self):
    x0 = np.array([[1], [2]])
@ -293,7 +327,7 @@ Method name is: tensorflow/serving/predict"""
    np.savez(input_path, a=x0, b=x1)
    input_str = 'x=' + input_path + '[c]'
    with self.assertRaises(RuntimeError):
      saved_model_cli.load_inputs_from_input_arg_string(input_str)
      saved_model_cli.load_inputs_from_input_arg_string(input_str, '')

  def testRunCommandExistingOutdir(self):
    self.parser = saved_model_cli.create_parser()

@ -994,7 +994,7 @@ class SVSummaryThread(coordinator.LooperThread):
      summary_strs = self._sess.run(self._sv.summary_op)
      global_step = None
    if self._sv.summary_writer:
      logging.info("Recording summary at step %d.", global_step)
      logging.info("Recording summary at step %s.", global_step)
      self._sv.summary_writer.add_summary(summary_strs, global_step)

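The one-character change above matters because global_step can be None on this path; '%s' renders None while '%d' raises, as a quick standalone check shows:

    print("Recording summary at step %s." % None)  # Recording summary at step None.
    print("Recording summary at step %d." % None)  # raises TypeError
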
@ -480,27 +480,56 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
    CUdevice device, DeviceOptions device_options, CudaContext** context) {
  *context = nullptr;

  CUcontext former_context = CurrentContext();
  if (former_context != nullptr) {
    LOG(WARNING) << "creating context when one is currently active; existing: "
                 << former_context;
  }

  int flags = 0;
  if (!DeviceOptionsToContextFlags(device_options, &flags)) {
    LOG(WARNING) << "could not convert all device options into context flags";
  }

  CUresult res;
  CUcontext former_context;
  CUcontext new_context;
  {
    // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
    // context creation: see http://b/13248943

#if CUDA_VERSION >= 7000
    res = cuDevicePrimaryCtxSetFlags(device, flags);
    {
      unsigned int former_primary_context_flags;
      int former_primary_context_is_active;
      CHECK_EQ(CUDA_SUCCESS,
               cuDevicePrimaryCtxGetState(device, &former_primary_context_flags,
                                          &former_primary_context_is_active));
      if (former_primary_context_flags != flags) {
        if (former_primary_context_is_active) {
          LOG(ERROR)
              << "The primary context is active and has a different flag set ("
              << former_primary_context_flags << ") than the desired flag set ("
              << flags << ").";
        } else {
          CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags));
        }
      }
    }

    former_context = CUDADriver::CurrentContextOrDie();
    res = cuDevicePrimaryCtxRetain(&new_context, device);
    if (former_context != nullptr) {
      if (former_context == new_context) {
        VLOG(2) << "The primary context " << former_context
                << " exists before initializing the StreamExecutor.";
      } else {
        LOG(WARNING) << "A non-primary context " << former_context
                     << " exists before initializing the StreamExecutor. We "
                        "haven't verified StreamExecutor works with that.";
      }
    }
#else
    former_context = CurrentContext();
    if (former_context != nullptr) {
      LOG(WARNING)
          << "creating context when one is currently active; existing: "
          << former_context;
    }
    res = cuCtxCreate(&new_context, flags, device);
#endif
  }

@ -334,8 +334,8 @@ int Main(int argc, char** argv) {
      Flag("show_memory", &show_memory, "whether to list stats by memory used"),
      Flag("memory_limit", &memory_limit,
           "how many items to show by memory used"),
      Flag("show_type", &show_time, "whether to list stats by op type"),
      Flag("show_summary", &show_time,
      Flag("show_type", &show_type, "whether to list stats by op type"),
      Flag("show_summary", &show_summary,
           "whether to show a summary of the stats"),
      Flag("show_flops", &show_flops, "whether to estimate the model's FLOPs"),
      Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"),

@ -9,7 +9,7 @@ exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

py_library(
py_binary(
    name = "grpc_tensorflow_server",
    srcs = [
        "grpc_tensorflow_server.py",

tensorflow/tools/dist_test/server/grpc_tensorflow_server.py (Executable file → Normal file)
@ -36,6 +36,7 @@ from __future__ import print_function

import argparse
import sys

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.platform import app
from tensorflow.python.training import server_lib
@ -103,8 +104,11 @@ def main(unused_args):
    raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
  server_def.task_index = FLAGS.task_id

  config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
      per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction))

  # Create GRPC Server instance
  server = server_lib.Server(server_def)
  server = server_lib.Server(server_def, config=config)

  # join() is blocking, unlike start()
  server.join()
@ -137,6 +141,11 @@ if __name__ == "__main__":
      default=0,
      help="Task index, e.g., 0"
  )
  parser.add_argument(
      "--gpu_memory_fraction",
      type=float,
      default=1.0,
      help="Fraction of GPU memory allocated",)
  parser.add_argument(
      "--verbose",
      type="bool",
@ -145,5 +154,6 @@ if __name__ == "__main__":
      default=False,
      help="Verbose mode"
  )

  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)

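A minimal sketch of the new wiring (mirroring the lines added above; create_local_server is used here so the snippet runs without a cluster spec):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.training import server_lib

    # Cap this process at half of the GPU's memory, as the new flag allows.
    config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
        per_process_gpu_memory_fraction=0.5))
    server = server_lib.Server.create_local_server(config=config)
    server.join()  # join() is blocking, unlike start()
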
third_party/ortools.BUILD (vendored)
@ -11,6 +11,3 @@ native.cc_library(
    ],
    visibility = ["//visibility:public"],
)


@ -1,4 +1,18 @@
#!/usr/bin/env bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

set -u  # Check for undefined variables
