Merge pull request #14814 from yifeif/branch_176709725

Branch 176709725
Authored by Gunhan Gulsoy on 2017-11-22 20:35:17 -08:00; committed by GitHub.
Commit ab0fcaceda
142 changed files with 3041 additions and 1538 deletions

View File

@ -905,6 +905,28 @@ def set_trisycl_include_dir(environ_cp):
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
trisycl_include_dir)
def set_trisycl_include_dir(environ_cp):
"""Set TRISYCL_INCLUDE_DIR."""
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
'[Default is %s]: ') % (
_DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
_DEFAULT_TRISYCL_INCLUDE_DIR)
if os.path.exists(trisycl_include_dir):
break
print('Invalid triSYCL include directory, %s cannot be found' %
(trisycl_include_dir))
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def set_mpi_home(environ_cp):
"""Set MPI_HOME."""
default_mpi_home = which('mpirun') or which('mpiexec') or ''

View File

@ -189,7 +189,7 @@ def tf_library(name, graph, config,
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_session_module=$(@D)/" + session_module_pb +
flags),
" " + flags),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,

View File

@ -76,7 +76,8 @@ class FusedBatchNormTest(XLATestCase):
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
epsilon = 0.001
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format)
@ -112,7 +113,8 @@ class FusedBatchNormTest(XLATestCase):
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
epsilon = 0.001
y, mean, var = nn.fused_batch_norm(
t_val,

View File

@ -67,6 +67,15 @@ class Client {
std::vector<GlobalData*> arguments;
ExecutionOptions execution_options;
ExecutionProfile* execution_profile;
ComputationInstance(const Computation& computation,
std::vector<GlobalData*> arguments,
ExecutionOptions execution_options,
ExecutionProfile* execution_profile)
: computation(computation),
arguments(std::move(arguments)),
execution_options(execution_options),
execution_profile(execution_profile) {}
};
// Executes a list of ComputationInstances and returns global data produced from
@ -133,7 +142,7 @@ class Client {
// Returns a vector of global data handles that point to the tuple elements.
StatusOr<std::vector<std::unique_ptr<GlobalData>>> DeconstructTuple(
const GlobalData& computation);
const GlobalData& data);
// Retrieves the statistics of the given computation.
StatusOr<ComputationStats> GetComputationStats(
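The ComputationInstance constructor added above makes it easier to assemble the argument list for Client::ExecuteParallel. A minimal sketch of the intended call pattern, mirroring the client_test change further down in this diff; the names client, computation, arg, devices, and execution_options are illustrative assumptions, not part of the change:
// Sketch only: assumes this runs inside a function returning a Status, with a
// connected Client* client, a built Computation computation, a GlobalData* arg
// already transferred to the server, default ExecutionOptions
// execution_options, and device handles from client->GetDeviceHandles().
std::vector<Client::ComputationInstance> instances;
ExecutionOptions options = execution_options;
*options.add_device_handles() = devices[0];  // run this instance on one device
// The new constructor bundles the computation, its arguments, the execution
// options, and an optional ExecutionProfile* into a single object.
instances.push_back(Client::ComputationInstance(
    computation, {arg}, options, /*execution_profile=*/nullptr));
// Executes every instance and returns one GlobalData handle per computation.
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<GlobalData>> results,
                    client->ExecuteParallel(instances));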

View File

@ -85,9 +85,9 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
HloOpcode opcode) {
HloComputation::Builder b("scalar_computation");
auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs"));
0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs"));
1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
auto scalar_op = b.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
opcode, scalar_lhs, scalar_rhs));
@ -152,22 +152,30 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
// Expand batch norm training into smaller HLO ops.
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
PrimitiveType ptype = operand_shape.element_type();
int64 feature_index = batch_norm->feature_index();
const int64 feature_count = operand_shape.dimensions(feature_index);
const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
auto elements_per_feature =
computation_->AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(size_in_elements / feature_count)));
auto elements_per_feature_literal =
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
const Shape feature_shape = scale->shape();
auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
HloInstruction::CreateConstant(std::move(epsilon_literal)));
std::vector<int64> dimensions_without_feature;
@ -184,7 +192,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
HloComputation* add_reduce_computation =
GetScalarBinaryComputation(F32, HloOpcode::kAdd);
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
// X^2.
auto operand_squared =
@ -243,8 +251,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
HloInstruction::CreateConstant(std::move(neg_half_literal)));
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon =
@ -286,6 +296,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
int64 feature_index = batch_norm->feature_index();
PrimitiveType ptype = operand_shape.element_type();
HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
@ -293,8 +304,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
HloInstruction::CreateConstant(std::move(epsilon_literal)));
std::vector<int64> dimensions_without_feature;
@ -321,8 +334,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
HloInstruction::CreateConstant(std::move(neg_half_literal)));
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon =
@ -373,6 +388,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
HloInstruction* activation = batch_norm->mutable_operand(0);
const Shape activation_shape = activation->shape();
PrimitiveType ptype = activation_shape.element_type();
HloInstruction* scale = batch_norm->mutable_operand(1);
const Shape feature_shape = scale->shape();
HloInstruction* mean = batch_norm->mutable_operand(2);
@ -383,18 +399,27 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
const int64 feature_count = activation_shape.dimensions(feature_index);
auto elements_per_feature =
computation_->AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(size_in_elements / feature_count)));
auto elements_per_feature_literal =
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
HloInstruction::CreateConstant(std::move(zero_literal)));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
HloInstruction::CreateConstant(std::move(neg_half_literal)));
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
HloInstruction::CreateConstant(std::move(epsilon_literal)));
std::vector<int64> dimensions_without_feature;
@ -442,7 +467,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
grad_output, activation_minus_mean));
HloComputation* add_reduce_computation =
GetScalarBinaryComputation(F32, HloOpcode::kAdd);
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
// sum(Grad[Y] * (X - E[X])).
auto sum_grad_output_times_activiation_minus_mean =
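For reference, the training expansion above implements the usual batch-norm formulas; the only change in this file is that the scalar constants (zero, epsilon, -0.5, elements_per_feature) and the add-reduction computation are now created in the operand's element type ptype rather than hard-coded F32, so that e.g. BF16 operands are rewritten entirely in BF16. In the notation of the surrounding comments:
// m      = size_in_elements / feature_count   (the elements_per_feature constant)
// E[X]   = Sum[X] / m
// Var[X] = Sum[X^2] / m - E[X]^2
// Y      = scale * (X - E[X]) / Sqrt[Var[X] + epsilon] + offset
//
// The neg_half (-0.5) constant implements 1 / Sqrt[Var[X] + epsilon] as
// (Var[X] + epsilon)^-0.5.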

View File

@ -197,28 +197,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
public:
static StatusOr<std::unordered_map<const HloInstruction*, size_t>>
GetCandidatesForComputation(HloComputation* computation) {
GetCandidatesForComputation(
HloComputation* computation,
const std::unordered_map<const HloInstruction*, int64>&
assigned_indices) {
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
CollectProfileCandidates profile_candidates_for_computation(
&hlo_to_profile_idx);
&hlo_to_profile_idx, assigned_indices);
TF_RETURN_IF_ERROR(
computation->Accept(&profile_candidates_for_computation));
return hlo_to_profile_idx;
}
private:
explicit CollectProfileCandidates(
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
: hlo_to_profile_idx_(hlo_to_profile_idx) {}
CollectProfileCandidates(
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx,
const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
: hlo_to_profile_idx_(hlo_to_profile_idx),
assigned_indices_(assigned_indices) {}
Status DefaultAction(HloInstruction* hlo_instruction) override {
hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()});
hlo_to_profile_idx_->insert(
{hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
return Status::OK();
}
Status HandleCall(HloInstruction* call) override {
TF_RETURN_IF_ERROR(DefaultAction(call));
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
assigned_indices_);
TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
return Status::OK();
}
@ -232,17 +239,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
Status HandleWhile(HloInstruction* xla_while) override {
TF_RETURN_IF_ERROR(DefaultAction(xla_while));
CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_);
CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
assigned_indices_);
TF_RETURN_IF_ERROR(
xla_while->while_condition()->Accept(&candidates_for_condition));
CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_);
CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
assigned_indices_);
TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));
return Status::OK();
}
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
} // namespace
@ -475,10 +485,27 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
HloComputation* computation = module->entry_computation();
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer;
if (module->config().hlo_profiling_enabled()) {
hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
TF_ASSIGN_OR_RETURN(
hlo_to_profile_idx,
CollectProfileCandidates::GetCandidatesForComputation(computation));
CollectProfileCandidates::GetCandidatesForComputation(
computation, hlo_profile_index_map->instruction_to_profile_idx()));
auto shape_size_bytes = [](const Shape& shape) {
// On the cpu, opaques are pointers.
if (ShapeUtil::IsOpaque(shape)) {
return static_cast<int64>(sizeof(void*));
}
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
};
HloCostAnalysis cost_analysis(shape_size_bytes);
hlo_profile_printer =
CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis);
}
std::unique_ptr<Executable> cpu_executable;
@ -544,8 +571,16 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
parallel_computations.emplace(to_apply, instruction);
}
// We always profile the entire computation as a whole, even if hlo
// profiling is disabled. When hlo profiling is disabled, we pass in a
// profile counter array of just one element, which corresponds to the whole
// computation.
size_t entry_computation_profile_idx =
hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
*module->entry_computation())
: 0;
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
hlo_to_profile_idx, hlo_to_profile_idx.size(),
hlo_to_profile_idx, entry_computation_profile_idx,
jit->target_machine(), jit->external_constant_pool());
std::unique_ptr<HloInstructionMap<string>> function_names(
@ -586,8 +621,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new ParallelCpuExecutable(
std::move(jit), std::move(assignment), std::move(module),
std::move(function_names), std::move(hlo_to_profile_idx),
std::move(aligned_constants)));
std::move(function_names), std::move(aligned_constants),
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)
@ -620,12 +655,22 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
// We always profile the entire computation as a whole, even if hlo
// profiling is disabled. When hlo profiling is disabled, we pass in a
// profile counter array of just one element, which corresponds to the whole
// computation.
size_t entry_computation_profile_idx =
hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
*module->entry_computation())
: 0;
// Each computation is a single function. Emit all embedded computations
// before the entry computation. The order of computations returned from
// GetEmbeddedComputations guarantees that a called computation occurs
// before a caller computation.
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
hlo_to_profile_idx, hlo_to_profile_idx.size(),
hlo_to_profile_idx, entry_computation_profile_idx,
jit->target_machine(), jit->external_constant_pool());
for (auto embedded_computation :
@ -659,7 +704,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable(
std::move(jit), std::move(assignment), std::move(module), function_name,
std::move(hlo_to_profile_idx)));
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/host/host_stream.h"
namespace se = ::perftools::gputools;
@ -54,11 +55,12 @@ CpuExecutable::CpuExecutable(
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx)
: Executable(std::move(hlo_module)),
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
assignment_(std::move(assignment)),
hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) {
assignment_(std::move(assignment)) {
// Resolve symbols in the constructor rather than at execution time to avoid
// races because FindSymbol is not thread safe.
llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name);
@ -182,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction(
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
// Allocate profiling counters for each hlo instruction that we would like to
// profile. Allocate an additional profile counter for the entire
// computation.
std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
// profile. Even when HLO profiling is disabled, we allocate a counter for the entire
// computation, which we use to update ExecutionProfile below.
std::vector<int64>* profile_counters = nullptr;
std::vector<int64> profile_counter_for_entry_computation;
if (hlo_execution_profile) {
profile_counters = hlo_execution_profile->mutable_profile_counters();
} else {
profile_counters = &profile_counter_for_entry_computation;
profile_counter_for_entry_computation.push_back(0);
}
// Call the computation function following the calling convention.
std::vector<void*> buffer_pointers;
@ -199,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction(
VLOG(3) << tensorflow::strings::Printf(
" func(void* result, void* params[%zu], void* temps[%zu], "
"uint64 profile_counters[%zu])",
args_array.size(), buffer_pointers.size(), profile_counters.size());
args_array.size(), buffer_pointers.size(), profile_counters->size());
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
@ -211,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction(
" temps = [%s]",
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
profile_counters.data());
profile_counters->data());
}
compute_function_(result_buffer, run_options, args_array.data(),
buffer_pointers.data(), profile_counters.data());
buffer_pointers.data(), profile_counters->data());
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@ -224,20 +233,46 @@ Status CpuExecutable::ExecuteComputeFunction(
const double nanoseconds = (end_micros - start_micros) * 1000.0;
execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
// The last profile counter is used for the computation as a whole.
execution_profile_.set_compute_cycle_count(profile_counters.back());
}
if (hlo_execution_profile != nullptr) {
hlo_execution_profile->set_total_cycles_executed(
*module().entry_computation(), profile_counters.back());
for (auto hlo_prof_idx : hlo_to_profile_idx_) {
const HloInstruction* hlo = hlo_prof_idx.first;
uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
if (hlo_execution_profile) {
execution_profile_.set_compute_cycle_count(
hlo_execution_profile->total_cycles_executed(
*module().entry_computation()));
} else {
execution_profile_.set_compute_cycle_count(profile_counters->back());
}
}
return Status::OK();
}
static void LogLiveAddresses(
const std::unordered_set<const void*>& marked_addresses) {
VLOG(3) << "Live addresses in output marking found "
<< marked_addresses.size() << " addresses:\n"
<< tensorflow::str_util::Join(
marked_addresses, ", ", [](string* out, const void* address) {
tensorflow::strings::StrAppend(
out, tensorflow::strings::Printf("%p", address));
});
}
static Status DeallocateTempBuffers(
DeviceMemoryAllocator* allocator, se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
const std::unordered_set<const void*>& marked_addresses) {
// Keep those marked live because they are referenced by the output of the
// computation and are needed by the service. They will be deallocated by the
// service.
for (size_t i = 0; i < buffers.size(); ++i) {
se::DeviceMemoryBase alloc = buffers[i];
if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) {
VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
<< alloc.opaque() << "]";
TF_RETURN_IF_ERROR(
allocator->Deallocate(stream->parent()->device_ordinal(), &alloc));
}
}
return Status::OK();
}
@ -263,26 +298,9 @@ StatusOr<perftools::gputools::DeviceMemoryBase> CpuExecutable::ExecuteOnStream(
MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
&marked_addresses);
VLOG(3) << "Live addresses in output marking found "
<< marked_addresses.size() << " addresses:\n"
<< tensorflow::str_util::Join(
marked_addresses, ", ", [](string* out, const void* address) {
tensorflow::strings::StrAppend(
out, tensorflow::strings::Printf("%p", address));
});
// Computation is done - deallocate temp buffers. Keep those marked live
// because they are referenced by the output of the computation and are needed
// by the service. They will be deallocated by the service.
for (size_t i = 0; i < buffers.size(); ++i) {
se::DeviceMemoryBase alloc = buffers[i];
if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) {
VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
<< alloc.opaque() << "]";
TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
stream->parent()->device_ordinal(), &alloc));
}
}
LogLiveAddresses(marked_addresses);
TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers,
marked_addresses));
return top_level_output;
}
@ -360,9 +378,44 @@ StatusOr<perftools::gputools::DeviceMemoryBase>
CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
// TODO(b/30671675): Implement asynchronous execution mode.
return Unimplemented(
"Asynchronous execution on stream is not yet supported on CPU.");
if (hlo_profiling_enabled()) {
return Unimplemented(
"Asynchronous execution on stream with hlo profiling is not yet "
"supported on CPU.");
}
auto* host_stream = dynamic_cast<perftools::gputools::host::HostStream*>(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
// Mark the buffers that are actually live (used in the output) when the
// computation finishes executing.
std::unordered_set<const void*> marked_addresses;
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
se::DeviceMemoryBase top_level_output = buffers[result_slice.index()];
MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
&marked_addresses);
LogLiveAddresses(marked_addresses);
host_stream->EnqueueTask([this, run_options, arguments, buffers,
marked_addresses, memory_allocator, stream]() {
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments,
buffers,
/*hlo_execution_profile=*/nullptr));
TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers,
marked_addresses));
});
return top_level_output;
}
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
@ -378,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const {
module().entry_computation()->root_instruction());
}
std::unique_ptr<HloCostAnalysis> CpuExecutable::CreateCostAnalysis() const {
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
}
} // namespace cpu
} // namespace xla
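Two behavioral notes on the ExecuteAsyncOnStream implementation added above: the compute function is merely enqueued on the HostStream, and temp-buffer deallocation happens inside the enqueued task, so the returned top-level output buffer is only safe to read after the stream has been synchronized. A hedged sketch of the expected caller-side pattern; executable, run_options, and arguments are assumed names, not part of this change:
// Sketch only: assumes a CpuExecutable* executable, a
// ServiceExecutableRunOptions run_options whose stream is backed by a
// host::HostStream, and ArraySlice<se::DeviceMemoryBase> arguments.
se::Stream* stream = run_options.stream();
TF_ASSIGN_OR_RETURN(
    se::DeviceMemoryBase result,
    executable->ExecuteAsyncOnStream(&run_options, arguments));
// The computation has only been enqueued at this point; block on the stream
// (or otherwise wait for it) before reading from `result`.
stream->BlockHostUntilDone();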

View File

@ -47,12 +47,12 @@ namespace cpu {
// architecture, so JIT-ed code and host code share the same ABI.
class CpuExecutable : public Executable {
public:
CpuExecutable(
std::unique_ptr<SimpleOrcJIT> jit,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx);
CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~CpuExecutable() override {}
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
@ -85,12 +85,10 @@ class CpuExecutable : public Executable {
static int64 ShapeSizeBytes(const Shape& shape);
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
// Type of the computation function we expect in the JIT.
using ComputeFunctionType = void (*)(
void* /*result*/, const ExecutableRunOptions* /*run_options*/,
const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/);
const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/);
const ComputeFunctionType& compute_function() const {
return compute_function_;
@ -145,9 +143,6 @@ class CpuExecutable : public Executable {
// Entry function name for the computation.
const string entry_function_name_;
// Maps HLOs to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
};

View File

@ -59,19 +59,20 @@ ParallelCpuExecutable::ParallelCpuExecutable(
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
aligned_constants)
: Executable(std::move(hlo_module)),
aligned_constants,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
assignment_(std::move(assignment)),
function_names_(std::move(function_names)),
hlo_to_profile_idx_(std::move(hlo_to_profile_idx)),
aligned_constants_(std::move(aligned_constants)) {}
// Type of the computation function we expect in the JIT.
using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
int64*, uint64*);
int64*, int64*);
// Given a pointer to an output buffer (following the CPU JIT calling
// conventions), mark addresses that are "live". The initial pointer itself is
@ -106,7 +107,7 @@ class Executor {
const ServiceExecutableRunOptions* run_options,
std::list<HloInstruction*>* pending,
HloInstructionMap<const void*>* results, void** temps_array,
uint64* profile_counters_array, const BufferAssignment* assignment)
int64* profile_counters_array, const BufferAssignment* assignment)
: functions_(functions),
run_options_(run_options),
pending_(pending),
@ -147,7 +148,7 @@ class Executor {
std::list<HloInstruction*>* pending_;
HloInstructionMap<const void*>* results_;
void** temps_array_;
uint64* profile_counters_array_;
int64* profile_counters_array_;
tensorflow::thread::ThreadPool* thread_pool_;
const BufferAssignment* assignment_;
@ -389,9 +390,11 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// Allocate profiling counters for each hlo instruction that we would like to
// profile. Allocate an additional profile counter for the entire
// computation.
std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
// profile.
std::vector<int64>* profile_counters = nullptr;
if (hlo_execution_profile) {
profile_counters = hlo_execution_profile->mutable_profile_counters();
}
std::vector<void*> buffer_pointers;
buffer_pointers.reserve(buffers.size());
@ -441,9 +444,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
// For example, if we expect a library conv/matmul call to run at max
// concurrency, we should not dispatch runnable instructions until the
// library call is finished (to avoid expensive cache invalidation).
Executor executor(functions, run_options, &pending, &results,
buffer_pointers.data(), profile_counters.data(),
assignment_.get());
Executor executor(
functions, run_options, &pending, &results, buffer_pointers.data(),
profile_counters ? profile_counters->data() : nullptr, assignment_.get());
TF_RETURN_IF_ERROR(executor.Run());
@ -453,18 +456,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
tensorflow::mutex_lock lock(mutex_);
double nanoseconds = (end_micros - start_micros) * 1000.0;
execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
// The last profile counter is used for the computation as a whole.
execution_profile_.set_compute_cycle_count(profile_counters.back());
}
if (hlo_execution_profile != nullptr) {
hlo_execution_profile->set_total_cycles_executed(entry_computation,
profile_counters.back());
for (auto hlo_prof_idx : hlo_to_profile_idx_) {
const HloInstruction* hlo = hlo_prof_idx.first;
uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
}
}
return Status::OK();
@ -618,10 +609,5 @@ const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
module().entry_computation()->root_instruction());
}
std::unique_ptr<HloCostAnalysis> ParallelCpuExecutable::CreateCostAnalysis()
const {
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
}
} // namespace cpu
} // namespace xla

View File

@ -52,10 +52,11 @@ class ParallelCpuExecutable : public Executable {
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
std::unordered_map<const HloInstruction*,
std::unique_ptr<unsigned char[]>>
aligned_constants);
aligned_constants,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~ParallelCpuExecutable() override {}
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
@ -95,8 +96,6 @@ class ParallelCpuExecutable : public Executable {
"Equality test on CPU parallel executable is not implemented.");
}
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
private:
// Allocate buffers required for execution and assign them to the elements of
// "buffers". "buffers" should be sized to the number of buffers in buffer
@ -143,9 +142,6 @@ class ParallelCpuExecutable : public Executable {
// Map containing the JITted function names for each HLO instruction.
const std::unique_ptr<const HloInstructionMap<string>> function_names_;
// Maps HLOs to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
// Map from HLO Constant instructions to a pointer to their literal data.
// The data stored in the protocol buffer might be insufficiently aligned,
// we create a sufficiently aligned copy and store it in this map.

View File

@ -44,8 +44,15 @@ namespace xla {
// interface that is used for launching compiled programs across platforms.
class Executable {
public:
explicit Executable(std::unique_ptr<const HloModule> hlo_module)
: hlo_module_(std::move(hlo_module)) {}
explicit Executable(std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: hlo_module_(std::move(hlo_module)),
hlo_profile_printer_(std::move(hlo_profile_printer)),
hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
CHECK_EQ(hlo_profile_printer_.get() == nullptr,
hlo_profile_index_map_.get() == nullptr);
}
virtual ~Executable() {}
// Enqueues the compilation result on the provided stream, passing the given
@ -123,12 +130,20 @@ class Executable {
"Equality test on this executable is not implemented.");
}
const HloProfilePrinter& hlo_profile_printer() const {
CHECK(hlo_profiling_enabled());
return *hlo_profile_printer_;
}
const HloProfileIndexMap& hlo_profile_index_map() const {
CHECK(hlo_profiling_enabled());
return *hlo_profile_index_map_;
}
// Returns whether this executable was compiled with HLO profiling support
// enabled. If not, the caller should not expect an hlo_execution_profile
// passed to ExecuteOnStream above to be populated during execution.
bool hlo_profiling_enabled() const {
return hlo_module_->config().hlo_profiling_enabled();
}
bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
const HloModule& module() const { return *hlo_module_; }
@ -160,10 +175,6 @@ class Executable {
static Status DumpToDirectory(const string& directory_path, string filename,
const SessionModule& session_module);
// Returns a cost analysis object appropriate for the platform on which this
// executable can run.
virtual std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const = 0;
protected:
mutable tensorflow::mutex mutex_;
@ -181,6 +192,9 @@ class Executable {
// Execution count, used to generate a unique filename for each dumped
// execution.
int64 execution_count_ = 0;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer_;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
};
template <typename ReturnT, typename ArgT>
@ -200,7 +214,8 @@ StatusOr<ReturnT> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled()
? MakeUnique<HloExecutionProfile>(module(), *CreateCostAnalysis())
? MakeUnique<HloExecutionProfile>(&hlo_profile_printer(),
&hlo_profile_index_map())
: nullptr;
auto return_value =
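Taken together with the backend changes elsewhere in this diff, the contract of the new Executable constructor is: pass both the HloProfilePrinter and the HloProfileIndexMap when the module was compiled with HLO profiling, or pass nullptr for both otherwise; the CHECK_EQ above enforces the both-or-neither invariant, and hlo_profiling_enabled() is now derived from the printer's presence instead of the module config. A short sketch of the two construction modes, using a hypothetical subclass MyExecutable and an assumed shape_size_fn for illustration:
// Profiling enabled: the compiler backend builds both objects up front.
auto index_map = MakeUnique<HloProfileIndexMap>(*module);
HloCostAnalysis cost_analysis(shape_size_fn);  // shape_size_fn is assumed
auto printer = CreateHloProfilePrinter(*index_map, cost_analysis);
auto with_profiling = MakeUnique<MyExecutable>(
    std::move(module), std::move(printer), std::move(index_map));
// Profiling disabled: pass nullptr for both, as InterpreterExecutable does.
// Passing only one of the two trips the CHECK_EQ in the base constructor.
auto without_profiling = MakeUnique<MyExecutable>(
    std::move(other_module), /*hlo_profile_printer=*/nullptr,
    /*hlo_profile_index_map=*/nullptr);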

View File

@ -465,10 +465,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
VLOG(2) << "Printing the thunk schedule...";
XLA_VLOG_LINES(2, thunk_schedule->ToString());
auto* gpu_executable =
new GpuExecutable(ptx, cubin, {cc_major, cc_minor},
std::move(thunk_schedule), std::move(module),
std::move(buffer_assignment), ShapeSizeBytesFunction());
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinter> profile_printer;
if (module->config().hlo_profiling_enabled()) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinter(*profile_index_map, cost_analysis);
}
auto* gpu_executable = new GpuExecutable(
ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
std::move(module), std::move(buffer_assignment),
std::move(profile_printer), std::move(profile_index_map));
if (embed_ir_in_executable) {
DCHECK_NE("", ir_module_string_before_opt);
gpu_executable->set_ir_module_string(ir_module_string_before_opt);

View File

@ -113,14 +113,15 @@ GpuExecutable::GpuExecutable(
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment,
HloCostAnalysis::ShapeSizeFunction shape_size_function)
: Executable(std::move(hlo_module)),
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
std::move(hlo_profile_index_map)),
ptx_(ptx),
cubin_(cubin),
compute_capability_(compute_capability),
thunk_schedule_(std::move(thunk_schedule)),
assignment_(std::move(assignment)),
shape_size_function_(std::move(shape_size_function)) {}
assignment_(std::move(assignment)) {}
Status GpuExecutable::ExecuteThunks(
const ServiceExecutableRunOptions* run_options,
@ -358,9 +359,5 @@ const PointsToSet& GpuExecutable::GetRootPointsToSet() const {
module().entry_computation()->root_instruction());
}
std::unique_ptr<HloCostAnalysis> GpuExecutable::CreateCostAnalysis() const {
return MakeUnique<HloCostAnalysis>(shape_size_function_);
}
} // namespace gpu
} // namespace xla

View File

@ -54,7 +54,8 @@ class GpuExecutable : public Executable {
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment,
HloCostAnalysis::ShapeSizeFunction shape_size_function);
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
// This should be called after set_ir_module_string.
const string& ir_module_string() const { return ir_module_string_; }
@ -95,8 +96,6 @@ class GpuExecutable : public Executable {
return Unimplemented("Equality test on GPU executable is not implemented.");
}
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
private:
// If `block_host_until_done` is false, execution will not block the host
// until the kernels have completed. This is used as an optimization for
@ -140,9 +139,6 @@ class GpuExecutable : public Executable {
// memory for every output/temp buffers.
const std::unique_ptr<const BufferAssignment> assignment_;
// Function to compute the size of a given Shape, in bytes.
const HloCostAnalysis::ShapeSizeFunction shape_size_function_;
TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
};

View File

@ -40,7 +40,7 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
}
}
static HloProfilePrinter CreateOwnedHloProfilePrinter(
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis) {
using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
@ -108,15 +108,15 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
delete[] computation_infos;
};
return HloProfilePrinter(computation_infos,
hlo_profile_index_map.computation_count(), deleter);
return MakeUnique<HloProfilePrinter>(
computation_infos, hlo_profile_index_map.computation_count(), deleter);
}
HloExecutionProfile::HloExecutionProfile(const HloModule& module,
const HloCostAnalysis& cost_analysis)
: hlo_profile_index_map_(module),
hlo_profile_printer_(
CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)),
HloExecutionProfile::HloExecutionProfile(
const HloProfilePrinter* hlo_profile_printer,
const HloProfileIndexMap* hlo_profile_index_map)
: hlo_profile_printer_(*hlo_profile_printer),
hlo_profile_index_map_(*hlo_profile_index_map),
profile_counters_(
/*count*/ hlo_profile_index_map_.total_count(),
/*value*/ 0) {}
@ -131,10 +131,4 @@ uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {
return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)];
}
string HloExecutionProfile::ToString(
const DeviceDescription& device_description) const {
return hlo_profile_printer_.ToString(profile_counters_.data(),
device_description.clock_rate_ghz());
}
} // namespace xla

View File

@ -77,6 +77,11 @@ class HloProfileIndexMap {
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_;
};
// Create an instance of `HloProfilePrinter` that owns its memory.
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis);
// Describes how much time each HLO operation took.
//
// Each HloComputation takes a certain number of cycles. This class helps break
@ -85,8 +90,8 @@ class HloExecutionProfile {
public:
using DeviceDescription = perftools::gputools::DeviceDescription;
HloExecutionProfile(const HloModule& module,
const HloCostAnalysis& cost_analysis);
HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer,
const HloProfileIndexMap* hlo_profile_index_map);
// Record how many cycles this HLO took to execute.
void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken);
@ -114,15 +119,16 @@ class HloExecutionProfile {
// for the operations in a given computation. Returns an empty string if it
// wasn't possible to generate a printable version.
string ToString(const DeviceDescription& device_description) const;
string ToString(const DeviceDescription& device_description) const {
return hlo_profile_printer_.ToString(profile_counters_.data(),
device_description.clock_rate_ghz());
}
std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }
private:
// hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to
// an index in profile_counters_.
HloProfileIndexMap hlo_profile_index_map_;
// Used to print profile_counters_ in a human readable form.
HloProfilePrinter hlo_profile_printer_;
const HloProfilePrinter& hlo_profile_printer_;
const HloProfileIndexMap& hlo_profile_index_map_;
// Stores per-Hlo profile counters. This is the only thing that changes when
// we execute an XLA computation.
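With this change HloExecutionProfile no longer owns an index map or builds its own printer; it only references objects owned elsewhere (typically by the Executable) and keeps the mutable counter vector. The construction pattern, which the updated hlo_execution_profile_test below exercises, looks roughly like this; module, shape_size_fn, and instr are assumed names:
// Sketch of the new wiring, assuming `module` is an HloModule, shape_size_fn
// is a shape-size function, and instr is a const HloInstruction* in `module`.
HloCostAnalysis cost_analysis(shape_size_fn);
HloProfileIndexMap profile_index_map(module);
std::unique_ptr<HloProfilePrinter> profile_printer =
    CreateHloProfilePrinter(profile_index_map, cost_analysis);
HloExecutionProfile execution_profile(profile_printer.get(),
                                      &profile_index_map);
execution_profile.SetCyclesTakenBy(instr, /*cycles_taken=*/1000);
// JIT-ed code can also fill the counters directly through
// execution_profile.mutable_profile_counters()->data().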

View File

@ -72,7 +72,11 @@ TEST_F(HloExecutionProfileTest, Basic) {
};
HloCostAnalysis cost_analysis(shape_size_function);
HloExecutionProfile execution_profile(*hlo_module, cost_analysis);
HloProfileIndexMap profile_index_map(*hlo_module);
std::unique_ptr<HloProfilePrinter> profile_printer =
CreateHloProfilePrinter(profile_index_map, cost_analysis);
HloExecutionProfile execution_profile(profile_printer.get(),
&profile_index_map);
const int64 add_cycles = 1000;
const int64 dot_cycles = 4000;

View File

@ -42,7 +42,8 @@ namespace sep = ::perftools::gputools::interpreter;
InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<const HloModule> hlo_module)
: Executable(std::move(hlo_module)) {}
: Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
/*hlo_profile_index_map=*/nullptr) {}
InterpreterExecutable::~InterpreterExecutable() {}
@ -156,10 +157,5 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteAsyncOnStream(
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
std::unique_ptr<HloCostAnalysis> InterpreterExecutable::CreateCostAnalysis()
const {
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
}
} // namespace interpreter
} // namespace xla

View File

@ -61,8 +61,6 @@ class InterpreterExecutable : public Executable {
static int64 ShapeSizeBytes(const Shape& shape);
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
};

View File

@ -575,12 +575,13 @@ Service::ExecuteParallelAndRegisterResult(
// profile.
for (auto& index_to_profiled_stream : index_to_profiled_streams) {
int64 device = index_to_profiled_stream.first;
auto& module = executables[device]->module();
se::Stream* stream = index_to_profiled_stream.second;
HloExecutionProfile hlo_profile(module,
*executables[device]->CreateCostAnalysis());
TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile(
&hlo_profile, stream->parent()));
Executable* executable = executables[device];
const HloModule& module = executable->module();
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(),
&executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
XLA_LOG_LINES(
tensorflow::INFO,
hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription()));

View File

@ -773,6 +773,11 @@ xla_test(
xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
blacklisted_backends = [
"cpu",
"cpu_parallel",
"gpu",
],
shard_count = 40,
deps = [
":test_utils",
@ -1343,6 +1348,7 @@ xla_test(
srcs = ["client_test.cc"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",

View File

@ -51,8 +51,7 @@ class Bfloat16Test : public ClientLibraryTestBase {
const ErrorSpec error_spec_{0.001, 0.001};
};
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(ScalarOperation)))) {
XLA_TEST_F(Bfloat16Test, ScalarOperation) {
ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
@ -62,8 +61,7 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
error_spec_);
}
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(NegateScalarF16)))) {
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
@ -71,5 +69,83 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
error_spec_);
}
XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
const int kFeatureIndex = 2;
ComputationBuilder builder(client_, TestName());
auto operand = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
auto scale = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)});
auto offset = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
auto tuple = builder.BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
auto expected = *Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.7f)}, {static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.65f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
const int kFeatureIndex = 2;
ComputationBuilder builder(client_, TestName());
auto operand = builder.ConstantR4FromArray4D<bfloat16>(
Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));
auto scale = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
auto mean = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)});
auto var = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
auto grad_output = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
builder.BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
auto expected = *Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace
} // namespace xla

View File

@ -29,6 +29,7 @@ def xla_test(name,
deps,
xla_test_library_deps=[],
backends=[],
blacklisted_backends=[],
args=[],
tags=[],
copts=[],
@ -92,17 +93,24 @@ def xla_test(name,
backends: A list of backends to generate tests for. Supported
values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will
be generated for all supported backends.
blacklisted_backends: A list of backends to NOT generate tests for.
args: Test arguments for the target.
tags: Tags for the target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
copts: Additional copts to pass to the build.
data: Additional data to pass to the build.
backend_tags: A dict mapping backend name to list of additional tags to
use for that target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
**kwargs: Additional keyword arguments to pass to native.cc_test.
"""
test_names = []
if not backends:
backends = all_backends
backends = [backend for backend in backends
if backend not in blacklisted_backends]
native.cc_library(
name="%s_lib" % name,
srcs=srcs,

View File

@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@ -42,26 +44,26 @@ TEST_F(ClientTest, ExecuteWithLayout) {
for (const std::vector<int64>& transfer_layout : layouts) {
b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
b.ConstantR2<int32>({{10, 20}, {30, 40}}));
auto computation = b.Build();
ASSERT_TRUE(computation.ok()) << computation.status();
TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
execute_layout);
std::unique_ptr<GlobalData> data =
client_->Execute(computation.ValueOrDie(), {}, &execution_options)
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
auto computed = client_->Transfer(*data, &expected_literal->shape());
TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));
LiteralTestUtil::AssertEqualShapesAndLayouts(
expected_literal->shape(), computed.ValueOrDie()->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
computed->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
}
}
}
@ -72,8 +74,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
b.ConstantR2<int32>({{10, 20}, {30, 40}})});
auto computation = b.Build();
ASSERT_TRUE(computation.ok()) << computation.status();
TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
ExecutionOptions execution_options = execution_options_;
// Create a result shape with one element column major and the other row
@ -85,10 +86,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})});
auto result =
client_
->ExecuteAndTransfer(computation.ValueOrDie(), {}, &execution_options)
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
result->tuple_literals(0));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
@ -107,5 +107,42 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0})));
}
TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}})));
ComputationBuilder b(client_, TestName() + ".add");
b.Add(b.Parameter(0, shape, "param_0"),
b.ConstantR2<int32>({{1, 2}, {3, 4}}));
TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build());
// We can't really test parallel execution on CPU since all of the cores in a
// CPU are presented as a single device. So for now we test "parallel"
// execution on a single device.
std::vector<Client::ComputationInstance> computation_instances;
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
ASSERT_EQ(devices.size(), 1);
ExecutionOptions options = execution_options_;
*options.add_device_handles() = devices[0];
computation_instances.push_back(Client::ComputationInstance(
add_with_one_arg, {const_arg.get()}, options, nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto results,
client_->ExecuteParallel(computation_instances));
auto expected_result = Literal::CreateR2<int32>({{6, 8}, {10, 12}});
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));
LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
}
} // namespace
} // namespace xla

View File

@ -37,7 +37,7 @@ set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
-std=c++11 -fno-rtti -fno-exceptions \
-O2 -Wno-narrowing -fomit-frame-pointer \
-mfpu=neon -mfloat-abi=softfp -fPIE \
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
-ftemplate-depth=900 \
-DGOOGLE_PROTOBUF_NO_RTTI \
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")

View File

@ -16,7 +16,7 @@
#include <string>
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/device_base.h"

View File

@ -408,7 +408,7 @@ tf_cc_test(
# Learner/stochastic
cc_library(
name = "gradient-stats",
hdrs = ["learner/stochastic/stats/gradient-stats.h"],
hdrs = ["learner/common/stats/gradient-stats.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
@ -417,7 +417,7 @@ cc_library(
cc_library(
name = "node-stats",
hdrs = ["learner/stochastic/stats/node-stats.h"],
hdrs = ["learner/common/stats/node-stats.h"],
deps = [
":gradient-stats",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
@ -429,7 +429,7 @@ cc_library(
cc_library(
name = "split-stats",
hdrs = ["learner/stochastic/stats/split-stats.h"],
hdrs = ["learner/common/stats/split-stats.h"],
deps = [
":node-stats",
],
@ -437,7 +437,7 @@ cc_library(
cc_library(
name = "feature-split-candidate",
hdrs = ["learner/stochastic/stats/feature-split-candidate.h"],
hdrs = ["learner/common/stats/feature-split-candidate.h"],
deps = [
":split-stats",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
@ -447,7 +447,7 @@ cc_library(
tf_cc_test(
name = "node-stats_test",
size = "small",
srcs = ["learner/stochastic/stats/node-stats_test.cc"],
srcs = ["learner/common/stats/node-stats_test.cc"],
deps = [
":node-stats",
"//tensorflow/core:tensor_testutil",

View File

@ -13,10 +13,10 @@
// limitations under the License.
//
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
namespace tensorflow {
@ -58,4 +58,4 @@ struct FeatureSplitCandidate {
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_

View File

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
#include <math.h>
@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) {
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/Eigenvalues"
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -298,4 +298,4 @@ struct NodeStats {
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
#include <string>
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
namespace tensorflow {
namespace boosted_trees {
@ -81,4 +81,4 @@ struct SplitStats {
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_

View File

@ -32,27 +32,41 @@ from tensorflow.python.platform import test
class CrfTest(test.TestCase):
def testCrfSequenceScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
expected_binary_score = sum(
transition_params[tag_indices[i], tag_indices[i + 1]]
for i in range(sequence_lengths - 1))
expected_sequence_score = expected_unary_score + expected_binary_score
self.assertAllClose(tf_sequence_score, expected_sequence_score)
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[4, 5, -3]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([1], dtype=np.int32)
]
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
with self.test_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
expected_binary_score = sum(
transition_params[tag_indices[i], tag_indices[i + 1]]
for i in range(sequence_lengths - 1))
expected_sequence_score = expected_unary_score + expected_binary_score
self.assertAllClose(tf_sequence_score, expected_sequence_score)
def testCrfUnaryScore(self):
inputs = np.array(
@ -89,38 +103,54 @@ class CrfTest(test.TestCase):
self.assertAllClose(tf_binary_score, expected_binary_score)
def testCrfLogNorm(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
all_sequence_scores = []
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[3, -1, 3]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([2], dtype=np.int32)
]
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequence_scores.append(
crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params)))
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
with self.test_session() as sess:
all_sequence_scores = []
brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
log_norm = crf.crf_log_norm(
inputs=array_ops.expand_dims(inputs, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
log_norm = array_ops.squeeze(log_norm, [0])
tf_brute_force_log_norm, tf_log_norm = sess.run(
[brute_force_log_norm, log_norm])
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequence_scores.append(
crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params)))
self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
log_norm = crf.crf_log_norm(
inputs=array_ops.expand_dims(inputs, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
log_norm = array_ops.squeeze(log_norm, [0])
tf_brute_force_log_norm, tf_log_norm = sess.run(
[brute_force_log_norm, log_norm])
self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
def testCrfLogLikelihood(self):
inputs = np.array(
@ -201,50 +231,66 @@ class CrfTest(test.TestCase):
expected_max_sequence[:sequence_lengths])
def testCrfDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[-1, 2, 1]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([2], dtype=np.int32)
]
with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)
with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []
tf_all_sequence_scores = sess.run(all_sequence_scores)
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)
expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
tf_all_sequence_scores = sess.run(all_sequence_scores)
actual_max_sequence, actual_max_score = crf.crf_decode(
array_ops.expand_dims(inputs, 0),
constant_op.constant(transition_params),
array_ops.expand_dims(sequence_lengths, 0))
actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
actual_max_score = array_ops.squeeze(actual_max_score, [0])
tf_actual_max_sequence, tf_actual_max_score = sess.run(
[actual_max_sequence, actual_max_score])
expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
self.assertAllClose(tf_actual_max_score, expected_max_score)
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
expected_max_sequence[:sequence_lengths])
actual_max_sequence, actual_max_score = crf.crf_decode(
array_ops.expand_dims(inputs, 0),
constant_op.constant(transition_params),
array_ops.expand_dims(sequence_lengths, 0))
actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
actual_max_score = array_ops.squeeze(actual_max_score, [0])
tf_actual_max_sequence, tf_actual_max_score = sess.run(
[actual_max_sequence, actual_max_score])
self.assertAllClose(tf_actual_max_score, expected_max_score)
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
expected_max_sequence[:sequence_lengths])
if __name__ == "__main__":

View File

@ -53,7 +53,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
"""
# Compute the scores of the given tag sequence.
unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
binary_scores = crf_binary_score(tag_indices, sequence_lengths,
transition_params)
sequence_scores = unary_scores + binary_scores
return sequence_scores
# If max_seq_len is 1, we skip the score calculation and simply gather the
# unary potentials of the single tag.
def _single_seq_fn():
batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
example_inds = array_ops.reshape(
math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
return array_ops.gather_nd(
array_ops.squeeze(inputs, [1]),
array_ops.concat([example_inds, tag_indices], axis=1))
def _multi_seq_fn():
# Compute the scores of the given tag sequence.
unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
binary_scores = crf_binary_score(tag_indices, sequence_lengths,
transition_params)
sequence_scores = unary_scores + binary_scores
return sequence_scores
return utils.smart_cond(
pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
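As a reading aid (not part of the diff above), the gather in `_single_seq_fn` reduces to plain NumPy indexing; here is a minimal sketch with hypothetical values for a batch of two length-1 sequences:
# Editorial NumPy sketch of the single-time-step branch of crf_sequence_score.
import numpy as np
inputs = np.array([[[4., 5., -3.]],      # example 0: one time step, three tags
                   [[3., -1., 3.]]])     # example 1
tag_indices = np.array([[1], [2]])       # chosen tag per example
squeezed = np.squeeze(inputs, axis=1)    # shape (2, 3), like squeeze(inputs, [1])
# gather_nd over [[0, 1], [1, 2]] picks the unary potential of each example's tag.
scores = squeezed[np.arange(2), tag_indices[:, 0]]   # -> array([5., 3.])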
def crf_log_norm(inputs, sequence_lengths, transition_params):
@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
# algorithm.
first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
first_input = array_ops.squeeze(first_input, [1])
rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
# Compute the alpha values in the forward algorithm in order to get the
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,
sequence_length=sequence_lengths - 1,
initial_state=first_input,
dtype=dtypes.float32)
log_norm = math_ops.reduce_logsumexp(alphas, [1])
return log_norm
# If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
# the "initial state" (the unary potentials).
def _single_seq_fn():
return math_ops.reduce_logsumexp(first_input, [1])
def _multi_seq_fn():
"""Forward computation of alpha values."""
rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
# Compute the alpha values in the forward algorithm in order to get the
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,
sequence_length=sequence_lengths - 1,
initial_state=first_input,
dtype=dtypes.float32)
log_norm = math_ops.reduce_logsumexp(alphas, [1])
return log_norm
max_seq_len = array_ops.shape(inputs)[1]
return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
true_fn=_single_seq_fn,
false_fn=_multi_seq_fn)
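For the `_single_seq_fn` shortcut above, a quick check (not part of the diff, values hypothetical): the partition function of a length-1 sequence is just the log-sum-exp of its unary potentials.
# Editorial NumPy sketch of the T == 1 case of crf_log_norm.
import numpy as np
first_input = np.array([[3., -1., 3.]])                  # one example, one time step
log_norm = np.log(np.sum(np.exp(first_input), axis=1))   # == reduce_logsumexp(first_input, [1])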
def crf_log_likelihood(inputs,
@ -437,45 +469,64 @@ def crf_decode(potentials, transition_params, sequence_length):
sequence_length: A [batch_size] vector of true sequence lengths.
Returns:
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
Contains the highest scoring tag indices.
best_score: A [batch_size] tensor, containing the score of decode_tags.
best_score: A [batch_size] vector, containing the score of `decode_tags`.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value
# If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
# and the max activation.
def _single_seq_fn():
squeezed_potentials = array_ops.squeeze(potentials, [1])
decode_tags = array_ops.expand_dims(
math_ops.argmax(squeezed_potentials, axis=1), 1)
best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score
# Computes forward decoding. Get last score and backpointers.
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn(
crf_fwd_cell,
inputs=inputs,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, O], [B, O]
backpointers = gen_array_ops.reverse_sequence(
backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O]
def _multi_seq_fn():
"""Decoding of highest scoring sequence."""
# Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
dtype=dtypes.int32) # [B]
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn(
crf_bwd_cell,
inputs=backpointers,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, 1]
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T]
decode_tags = gen_array_ops.reverse_sequence(
decode_tags, sequence_length, seq_dim=1) # [B, T]
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value
best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score
# Computes forward decoding. Get last score and backpointers.
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O]
crf_fwd_cell,
inputs=inputs,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32)
backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O]
backpointers, sequence_length - 1, seq_dim=1)
# Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B]
dtype=dtypes.int32)
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1]
crf_bwd_cell,
inputs=backpointers,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32)
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T]
axis=1)
decode_tags = gen_array_ops.reverse_sequence( # [B, T]
decode_tags, sequence_length, seq_dim=1)
best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score
return utils.smart_cond(
pred=math_ops.equal(
potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
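Similarly, a minimal sketch (not part of the diff; shapes follow the [B, T, O] comments above) of what the new `_single_seq_fn` in `crf_decode` computes when `max_seq_len` is 1: the best tag is the argmax of the unary potentials and the best score their maximum.
# Editorial NumPy sketch with hypothetical potentials.
import numpy as np
potentials = np.array([[[-1., 2., 1.]]])                        # B=1, T=1, O=3
decode_tags = np.argmax(potentials[:, 0, :], axis=1)[:, None]   # [[1]]
best_score = np.max(potentials[:, 0, :], axis=1)                # [2.]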

View File

@ -187,6 +187,7 @@ py_test(
"manual", # b/67958761
],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",

View File

@ -723,5 +723,41 @@ class BatchDatasetSerializationTest(
num_outputs)
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def testPaddedBatch(self):
def build_dataset(seq_lens):
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
lambda x: array_ops.fill([x], x)).padded_batch(
4, padded_shapes=[-1])
seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
self.run_core_tests(lambda: build_dataset(seq_lens1),
lambda: build_dataset(seq_lens2), 8)
def testPaddedBatchNonDefaultPadding(self):
def build_dataset(seq_lens):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
padded_shape = [-1]
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
fill_tuple).padded_batch(
4,
padded_shapes=(padded_shape, padded_shape),
padding_values=(-1, "<end>"))
seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
self.run_core_tests(lambda: build_dataset(seq_lens1),
lambda: build_dataset(seq_lens2), 8)
if __name__ == "__main__":
test.main()

View File

@ -22,8 +22,10 @@ import math
import threading
import time
import numpy as np
from six.moves import zip_longest
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.framework import dtypes
@ -209,6 +211,46 @@ class InterleaveDatasetTest(test.TestCase):
sess.run(get_next)
class InterleaveDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_iterator_graph(self, input_values, cycle_length, block_length):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
cycle_length, block_length)
def testSerializationCore(self):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
# cycle_length > 1, block_length > 1
cycle_length = 2
block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
lambda: self._build_iterator_graph(
input_values, cycle_length * 2, block_length * 1),
num_outputs)
# cycle_length = 1
cycle_length = 1
block_length = 3
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
None, num_outputs)
# block_length = 1
cycle_length = 2
block_length = 1
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
None, num_outputs)
# pylint: enable=g-long-lambda
class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self):

View File

@ -41,6 +41,7 @@ def try_import(name): # pylint: disable=invalid-name
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module
stats = try_import("scipy.stats")
@ -62,9 +63,9 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected, scale_shape.eval())
loc = array_ops.zeros(loc_shape)
scale = array_ops.ones(scale_shape)
self.assertAllEqual(
expected,
array_ops.shape(cauchy_lib.Cauchy(loc, scale).sample()).eval())
self.assertAllEqual(expected,
array_ops.shape(
cauchy_lib.Cauchy(loc, scale).sample()).eval())
def _testParamStaticShapes(self, sample_shape, expected):
param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape)
@ -92,8 +93,7 @@ class CauchyTest(test.TestCase):
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
log_pdf = cauchy.log_prob(x)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.eval().shape)
self.assertAllEqual(cauchy.batch_shape, log_pdf.shape)
@ -115,16 +115,15 @@ class CauchyTest(test.TestCase):
with self.test_session():
batch_size = 6
loc = constant_op.constant([[3.0, -3.0]] * batch_size)
scale = constant_op.constant([[np.sqrt(10.0), np.sqrt(15.0)]] *
batch_size)
scale = constant_op.constant(
[[np.sqrt(10.0), np.sqrt(15.0)]] * batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
log_pdf = cauchy.log_prob(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.shape, (6, 2))
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.eval().shape)
self.assertAllEqual(cauchy.batch_shape, log_pdf.shape)
@ -248,8 +247,7 @@ class CauchyTest(test.TestCase):
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
entropy = cauchy.entropy()
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
entropy.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
entropy.eval().shape)
self.assertAllEqual(cauchy.batch_shape, entropy.shape)
@ -257,7 +255,7 @@ class CauchyTest(test.TestCase):
if not stats:
return
expected_entropy = stats.cauchy(loc, scale).entropy()
expected_entropy = stats.cauchy(loc, scale[0]).entropy().reshape((1, 3))
self.assertAllClose(expected_entropy, entropy.eval())
def testCauchyMode(self):
@ -368,8 +366,8 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)
expected_shape = (tensor_shape.TensorShape(
[n.eval()]).concatenate(cauchy.batch_shape))
expected_shape = (
tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape))
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)
@ -385,18 +383,18 @@ class CauchyTest(test.TestCase):
samples = cauchy.sample(n)
sample_values = samples.eval()
self.assertEqual(samples.shape, (100000, batch_size, 2))
self.assertAllClose(np.median(sample_values[:, 0, 0]),
loc_v[0], atol=1e-1)
self.assertAllClose(np.median(sample_values[:, 0, 1]),
loc_v[1], atol=1e-1)
self.assertAllClose(
np.median(sample_values[:, 0, 0]), loc_v[0], atol=1e-1)
self.assertAllClose(
np.median(sample_values[:, 0, 1]), loc_v[1], atol=1e-1)
expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval()))
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)
expected_shape = (tensor_shape.TensorShape(
[n.eval()]).concatenate(cauchy.batch_shape))
expected_shape = (
tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape))
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)
@ -428,9 +426,12 @@ class CauchyTest(test.TestCase):
self.assertEqual(cauchy.event_shape, ())
self.assertAllEqual(cauchy.event_shape_tensor().eval(), [])
self.assertAllEqual(
sess.run(cauchy.batch_shape_tensor(),
feed_dict={loc: 5.0,
scale: [1.0, 2.0]}), [2])
sess.run(
cauchy.batch_shape_tensor(),
feed_dict={
loc: 5.0,
scale: [1.0, 2.0]
}), [2])
if __name__ == "__main__":

View File

@ -30,7 +30,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
__all__ = [
"Cauchy",
]
@ -97,7 +96,7 @@ class Cauchy(distribution.Distribution):
validate_args=False,
allow_nan_stats=True,
name="Cauchy"):
"""Construct Cauchy distributions with loc and and scale `loc` and `scale`.
"""Construct Cauchy distributions.
The parameters `loc` and `scale` must be shaped in a way that supports
broadcasting (e.g. `loc + scale` is a valid operation).
@ -121,8 +120,8 @@ class Cauchy(distribution.Distribution):
"""
parameters = locals()
with ops.name_scope(name, values=[loc, scale]):
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
with ops.control_dependencies([check_ops.assert_positive(scale)]
if validate_args else []):
self._loc = array_ops.identity(loc, name="loc")
self._scale = array_ops.identity(scale, name="scale")
check_ops.assert_same_float_dtype([self._loc, self._scale])
@ -138,8 +137,8 @@ class Cauchy(distribution.Distribution):
@staticmethod
def _param_shapes(sample_shape):
return dict(
zip(("loc", "scale"), ([ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32)] * 2)))
zip(("loc", "scale"),
([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))
@property
def loc(self):
@ -153,13 +152,10 @@ class Cauchy(distribution.Distribution):
def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
array_ops.shape(self.loc),
array_ops.shape(self.scale))
array_ops.shape(self.loc), array_ops.shape(self.scale))
def _batch_shape(self):
return array_ops.broadcast_static_shape(
self.loc.shape,
self.scale.shape)
return array_ops.broadcast_static_shape(self.loc.shape, self.scale.shape)
def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32)

View File

@ -116,6 +116,7 @@ py_library(
deps = [
":clip_weights",
":conditioning_utils",
":tensor_pool",
":virtual_batchnorm",
"//tensorflow/python:util",
],
@ -219,6 +220,37 @@ py_test(
],
)
py_library(
name = "tensor_pool",
srcs = [
"python/features/python/tensor_pool.py",
"python/features/python/tensor_pool_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:util",
],
)
py_test(
name = "tensor_pool_test",
srcs = ["python/features/python/tensor_pool_test.py"],
srcs_version = "PY2AND3",
deps = [
":tensor_pool",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//third_party/py/numpy",
],
)
py_library(
name = "virtual_batchnorm",
srcs = [

View File

@ -0,0 +1,35 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor pool stores values from an input tensor and returns a stored one.
See the following papers for more details.
1) `Learning from simulated and unsupervised images through adversarial
training` (https://arxiv.org/abs/1612.07828).
2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks` (https://arxiv.org/abs/1703.10593).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.tensor_pool_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = tensor_pool_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,118 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor pool stores values from an input tensor and returns a stored one.
We use this to keep a history of values created by a generator, so that
a discriminator can randomly be trained on some older samples, not just the
current one. This helps keep the discriminator from getting too far ahead of the
generator, and it also keeps the system from oscillating when the discriminator
forgets too quickly what past samples from the generator looked like.
See the following papers for more details.
1) `Learning from simulated and unsupervised images through adversarial
training` (https://arxiv.org/abs/1612.07828).
2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks` (https://arxiv.org/abs/1703.10593).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import random_ops
__all__ = [
'tensor_pool',
]
def tensor_pool(input_value,
pool_size,
pooling_probability=0.5,
name='tensor_pool'):
"""Queue storing input values and returning random previously stored ones.
Every time the returned `output_value` is evaluated, `input_value` is
evaluated and its value is either returned directly (with probability
`1 - pooling_probability`) or stored in the pool, in which case a random one of
the samples currently in the pool is popped and returned instead. As long as
the pool is not yet full, `input_value` is always returned directly, as well as
stored in the pool. Note that during inference/testing it may be appropriate to
set `pool_size = 0` or `pooling_probability = 0`.
Args:
input_value: A `Tensor` from which to read values to be pooled.
pool_size: An integer specifying the maximum size of the pool.
pooling_probability: A float `Tensor` specifying the probability of getting
a value from the pool, as opposed to just the current input.
name: A string prefix for the name scope for all tensorflow ops.
Returns:
A `Tensor` which, with the given probability, is either `input_value` or a
randomly chosen sample that was previously inserted into the pool.
Raises:
ValueError: If `pool_size` is negative.
"""
pool_size = int(pool_size)
if pool_size < 0:
raise ValueError('`pool_size` is negative.')
elif pool_size == 0:
return input_value
with ops.name_scope('{}_pool_queue'.format(name),
values=[input_value, pooling_probability]):
pool_queue = data_flow_ops.RandomShuffleQueue(
capacity=pool_size,
min_after_dequeue=0,
dtypes=[input_value.dtype],
shapes=None)
# In pseudocode, this does the following:
# if not pool_full:
# enqueue(input_value)
# return input_value
# else
# dequeue_value = dequeue_random_sample()
# enqueue(input_value)
# if rand() < pooling_probability:
# return dequeue_value
# else
# return input_value
def _get_input_value_pooled():
enqueue_op = pool_queue.enqueue(input_value)
with ops.control_dependencies([enqueue_op]):
return array_ops.identity(input_value)
def _get_random_pool_value_and_enqueue_input():
dequeue_value = pool_queue.dequeue()
with ops.control_dependencies([dequeue_value]):
enqueue_op = pool_queue.enqueue(input_value)
with ops.control_dependencies([enqueue_op]):
prob = random_ops.random_uniform(
(), dtype=dtypes.float32) < pooling_probability
return control_flow_ops.cond(prob, lambda: dequeue_value,
lambda: input_value)
output_value = control_flow_ops.cond(
pool_queue.size() < pool_size, _get_input_value_pooled,
_get_random_pool_value_and_enqueue_input)
return output_value
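A minimal usage sketch of the new op (not part of the diff; the placeholder shape and variable names are hypothetical, and the TF 1.x graph-mode API used throughout this file is assumed):
# Feed generator outputs through the pool so the discriminator sometimes sees
# an older generated batch instead of the current one.
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl

fake_batch = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
pooled_batch = tensor_pool_impl.tensor_pool(
    fake_batch, pool_size=50, pooling_probability=0.5)
# pooled_batch is then what the discriminator is fed during training.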

View File

@ -0,0 +1,94 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.gan.python.features.tensor_pool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl as tensor_pool
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class TensorPoolTest(test.TestCase):
def test_pool_unknown_input_shape(self):
"""Checks that `input_value` can have unknown shape."""
input_value = array_ops.placeholder(
dtype=dtypes.int32, shape=[None, None, 3])
output_value = tensor_pool.tensor_pool(input_value, pool_size=10)
with self.test_session(use_gpu=True) as session:
for i in range(10):
session.run(output_value, {input_value: [[[i] * 3]]})
session.run(output_value, {input_value: [[[i] * 3] * 2]})
session.run(output_value, {input_value: [[[i] * 3] * 5] * 2})
def test_pool_sequence(self):
"""Checks that values are pooled and returned maximally twice."""
input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
output_value = tensor_pool.tensor_pool(input_value, pool_size=10)
with self.test_session(use_gpu=True) as session:
outs = []
for i in range(50):
out = session.run(output_value, {input_value: i})
outs.append(out)
self.assertLessEqual(out, i)
_, counts = np.unique(outs, return_counts=True)
# Check that each value is returned at most twice.
self.assertTrue((counts <= 2).all())
def test_never_pool(self):
"""Checks that setting `pooling_probability` to zero works."""
input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
output_value = tensor_pool.tensor_pool(
input_value, pool_size=10, pooling_probability=0.0)
with self.test_session(use_gpu=True) as session:
for i in range(50):
out = session.run(output_value, {input_value: i})
self.assertEqual(out, i)
def test_pooling_probability(self):
"""Checks that `pooling_probability` works."""
input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
pool_size = 10
pooling_probability = 0.2
output_value = tensor_pool.tensor_pool(
input_value,
pool_size=pool_size,
pooling_probability=pooling_probability)
with self.test_session(use_gpu=True) as session:
not_pooled = 0
total = 1000
for i in range(total):
out = session.run(output_value, {input_value: i})
if out == i:
not_pooled += 1
self.assertAllClose(
(not_pooled - pool_size) / (total - pool_size),
1 - pooling_probability,
atol=0.03)
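Back-of-envelope for the assertion above (not part of the diff): the first `pool_size` inputs always come back directly, and each later input comes back directly with probability `1 - pooling_probability`, so
# Expected direct (unpooled) returns for the numbers used in this test.
pool_size, total, p = 10, 1000, 0.2
expected_not_pooled = pool_size + (total - pool_size) * (1 - p)   # ~802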
if __name__ == '__main__':
test.main()

View File

@ -40,6 +40,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
cov_ema_decay,
damping,
layer_collection,
var_list=None,
momentum=0.,
momentum_type="regular",
norm_constraint=None,
@ -66,6 +67,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
blocks, kronecker factors, and losses associated with the
graph. The layer_collection cannot be modified after KfacOptimizer's
initialization.
var_list: Optional list or tuple of variables to train. Defaults to the
list of variables collected in the graph under the key
`GraphKeys.TRAINABLE_VARIABLES`.
momentum: The momentum value for this optimizer. Only applies when
momentum_type is 'regular' or 'adam'. (Default: 0)
momentum_type: The type of momentum to use in this optimizer, one of
@ -96,9 +100,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
or 'adam'.
"""
# We may consider determining the set of variables some other way, but for
# now it's just all the trainable variables.
variables = tf_variables.trainable_variables()
variables = var_list
if variables is None:
variables = tf_variables.trainable_variables()
self._fisher_est = est.FisherEstimator(
variables,
@ -123,7 +127,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
raise ValueError("Momentum must be unspecified if using a momentum_type "
"other than 'regular' or 'adam'.")
self._momentum = ops.convert_to_tensor(momentum, name="momentum")
self._momentum = momentum
self._momentum_type = momentum_type
self._norm_constraint = norm_constraint
@ -313,14 +317,17 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._batch_size, dtype=fft_precon_grads[0].dtype)
# compute the entries of the 2x2 matrix
m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size
+ self.damping * _inner_product_list(precon_grads, precon_grads))
m_11 = (
_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
self.damping * _inner_product_list(precon_grads, precon_grads))
m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size
+ self.damping * _inner_product_list(prev_updates, precon_grads))
m_21 = (
_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
self.damping * _inner_product_list(prev_updates, precon_grads))
m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size
+ self.damping * _inner_product_list(prev_updates, prev_updates))
m_22 = (
_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
self.damping * _inner_product_list(prev_updates, prev_updates))
def non_zero_prevupd_case():
r"""Computes optimal (alpha, mu) given non-zero previous update.
@ -406,8 +413,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
grads = list(grad for (grad, _) in grads_and_vars)
variables = list(var for (_, var) in grads_and_vars)
# previous updates are the negative velocities (up to scaling by LR)
prev_updates = list(-self._zeros_slot(var, "velocity", self._name)
for var in variables)
prev_updates = list(
-self._zeros_slot(var, "velocity", self._name) for var in variables)
# Compute optimal velocity update parameters according to quadratic model
alpha, mu, _ = self._compute_qmodel_hyperparams(

View File

@ -28,7 +28,6 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
# Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky"
@ -202,9 +201,18 @@ def posdef_inv_cholesky(tensor, identity, damping):
return linalg_ops.cholesky_solve(chol, identity)
def posdef_inv_eig(tensor, identity, damping):
"""Computes inverse(tensor + damping * identity) with eigendecomposition."""
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
tensor + damping * identity)
return math_ops.matmul(
eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
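For reference (not part of the diff), the identity the new eigendecomposition path relies on: for symmetric `tensor`, `inv(tensor + damping * identity)` equals `V @ diag(1/mu) @ V.T` with `(mu, V)` the eigendecomposition of the damped matrix, which is exactly what the broadcasted division above computes.
# Editorial NumPy verification with hypothetical values.
import numpy as np
T, damping = np.array([[2., 1.], [1., 3.]]), 0.1
mu, V = np.linalg.eigh(T + damping * np.eye(2))
approx_inv = (V / mu) @ V.T   # matmul(eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
assert np.allclose(approx_inv, np.linalg.inv(T + damping * np.eye(2)))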
posdef_inv_funcs = {
"matrix_inverse": posdef_inv_matrix_inverse,
"cholesky": posdef_inv_cholesky,
"eig": posdef_inv_eig,
}
@ -261,8 +269,8 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
# generated by the first gradients_impl.gradients call.
us = [array_ops.zeros_like(y) + float("nan") for y in ys]
dydxs = gradients_impl.gradients(ys, xs, grad_ys=us,
stop_gradients=stop_gradients)
dydxs = gradients_impl.gradients(
ys, xs, grad_ys=us, stop_gradients=stop_gradients)
# Work around the strange types that gradients_impl.gradients returns but
# cannot itself consume.
@ -278,3 +286,6 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
return dysdx
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
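A note on the fwd_gradients hunk above (the standard double reverse-mode identity, not something introduced by this change): with a dummy seed $u$,

\[
\texttt{dydxs}(u) = \frac{\partial\,(u^{\top} y)}{\partial x} = J^{\top} u,
\qquad
\texttt{dysdx} = \frac{\partial\,\big((J^{\top} u)^{\top} g_x\big)}{\partial u} = J\, g_x ,
\]

so differentiating the vector-Jacobian product with respect to the seed recovers the Jacobian-vector product $J\,g_x$; the NaN initialization of the seed presumably just makes any accidental use of its numeric value show up immediately.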

View File

@ -309,7 +309,6 @@ def _fused_batch_norm(inputs,
new_shape = [-1, channels, 1, 1]
inputs = array_ops.reshape(inputs, new_shape)
inputs_shape = inputs.get_shape()
dtype = inputs.dtype.base_dtype
if data_format == DATA_FORMAT_NHWC:
params_shape = inputs_shape[-1:]
else:

View File

@ -1779,7 +1779,8 @@ class BatchNormTest(test.TestCase):
dtype = dtypes.float32
height, width = 3, 3
with self.test_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype)
images = np.random.uniform(size=(5, height, width, 3)).astype(
dtype.as_numpy_dtype)
output = _layers.batch_norm(images, fused=fused)
expected_name = ('BatchNorm/FusedBatchNorm' if fused else
'BatchNorm/batchnorm')
@ -2665,18 +2666,18 @@ class BatchNormTest(test.TestCase):
# Test case for 11673
with self.test_session() as sess:
a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW',
zero_debias_moving_mean=True)
_layers.batch_norm(
a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10))
b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW',
zero_debias_moving_mean=True)
_layers.batch_norm(
a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True)
sess.run(variables_lib.global_variables_initializer())
def testVariablesAreFloat32(self):
height, width = 3, 3
with self.test_session():
images = random_ops.random_uniform((5, height, width, 3),
seed=1, dtype=dtypes.float16)
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float16)
_layers.batch_norm(images, scale=True)
beta = variables.get_variables_by_name('beta')[0]
gamma = variables.get_variables_by_name('gamma')[0]
@ -2691,17 +2692,13 @@ class BatchNormTest(test.TestCase):
channels = shape[1]
images = np.arange(np.product(shape), dtype=dtype).reshape(shape)
beta = init_ops.constant_initializer(
np.arange(
2, channels + 2, dtype=np.float32))
np.arange(2, channels + 2, dtype=np.float32))
gamma = init_ops.constant_initializer(
np.arange(
10, channels + 10, dtype=np.float32) * 2.0)
np.arange(10, channels + 10, dtype=np.float32) * 2.0)
mean = init_ops.constant_initializer(
np.arange(
3, channels + 3, dtype=np.float32) * 5.0)
np.arange(3, channels + 3, dtype=np.float32) * 5.0)
variance = init_ops.constant_initializer(
np.arange(
1, channels + 1, dtype=np.float32) * 4.0)
np.arange(1, channels + 1, dtype=np.float32) * 4.0)
output = _layers.batch_norm(
images,
fused=True,
@ -2726,7 +2723,6 @@ class BatchNormTest(test.TestCase):
res_16 = self._runFusedBatchNorm(shape, np.float16)
self.assertAllClose(res_32, res_16, rtol=1e-3)
def testAdjustmentCreated(self):
# Tests that the adjustment is appropriately passed to and used by the core
# BN layer.

View File

@ -28,7 +28,6 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@ -369,10 +368,11 @@ class DataFeeder(object):
if x_is_dict:
num_samples = list(self._x.values())[0].shape[0]
elif tensor_util.is_tensor(self._x):
num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int
num_samples = self._x.shape[
0].value # shape will be a Dimension, extract an int
else:
num_samples = self._x.shape[0]
if self._shuffle:
self.indices = self.random_state.permutation(num_samples)
else:

View File

@ -251,8 +251,9 @@ class SdcaModel(object):
result_dense = 0.0
for i in range(len(dense_variables)):
result_dense += math_ops.matmul(
dense_features[i], array_ops.expand_dims(dense_variables[i], -1))
result_dense += math_ops.matmul(dense_features[i],
array_ops.expand_dims(
dense_variables[i], -1))
# Reshaping to allow shape inference at graph construction time.
return array_ops.reshape(result_dense, [-1]) + result_sparse

View File

@ -164,8 +164,8 @@ def toco_convert(input_data,
toco = _toco_flags_pb2.TocoFlags()
toco.input_format = input_format
toco.output_format = output_format
toco.drop_control_dependency = drop_control_dependency
model = _model_flags_pb2.ModelFlags()
model.drop_control_dependency = drop_control_dependency
toco.inference_type = inference_type
for idx, input_tensor in enumerate(input_tensors):
if input_tensor.dtype == _dtypes.float32:

View File

@ -40,6 +40,7 @@ from six import StringIO
# TODO(aselle): Disable GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# pylint: disable=g-import-not-at-top
import tensorflow as tf
from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader
@ -383,7 +384,7 @@ def make_zip_of_tests(zip_path,
report["toco_log"] = ""
tf.reset_default_graph()
with tf.device('/cpu:0'):
with tf.device("/cpu:0"):
try:
inputs, outputs = make_graph(param_dict_real)
except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,

View File

@ -194,7 +194,6 @@ struct ParsedModelFlags {
Arg<string> input_data_type;
Arg<string> input_data_types;
Arg<bool> variable_batch = Arg<bool>(false);
Arg<bool> drop_control_dependency = Arg<bool>(false);
Arg<toco::IntList> input_shape;
Arg<toco::StringMapList> rnn_states;
Arg<toco::StringMapList> model_checks;
@ -224,6 +223,7 @@ struct ParsedTocoFlags {
// Deprecated flags
Arg<string> input_type;
Arg<string> input_types;
Arg<bool> drop_control_dependency = Arg<bool>(false);
};
} // namespace toco

View File

@ -35,8 +35,11 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::TensorProto;
@ -1500,10 +1503,29 @@ void ConvertOperator(const Model& model, const Operator& src_op,
}
}
void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
void AddPlaceholder(const string& name, ArrayDataType type,
GraphDef* tensorflow_graph) {
auto* placeholder = tensorflow_graph->add_node();
placeholder->set_op("Placeholder");
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
switch (type) {
case ArrayDataType::kBool:
(*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
break;
case ArrayDataType::kFloat:
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
break;
case ArrayDataType::kUint8:
(*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
break;
case ArrayDataType::kInt32:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
break;
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;
default:
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
}
placeholder->set_name(name);
}
@ -1531,7 +1553,9 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
void ExportTensorFlowGraphDefImplementation(const Model& model,
GraphDef* tensorflow_graph) {
for (const auto& input_array : model.flags.input_arrays()) {
AddPlaceholder(input_array.name(), tensorflow_graph);
AddPlaceholder(input_array.name(),
model.arrays.at(input_array.name())->data_type,
tensorflow_graph);
}
for (const auto& rnn_state : model.flags.rnn_states()) {
AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),

File diff suppressed because it is too large

View File

@ -23,11 +23,19 @@ limitations under the License.
namespace toco {
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def);
struct TensorFlowImportFlags {
// If true, control dependencies will be dropped immediately
// during the import of the TensorFlow GraphDef.
bool drop_control_dependency = false;
};
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const string& input_file_contents);
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const tensorflow::GraphDef& graph_def);
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const string& input_file_contents);
} // namespace toco

View File

@ -112,13 +112,6 @@ bool ParseModelFlagsFromCommandLineFlags(
"exclusive "
"with the 'batch' field: at most one of these two fields can be "
"set."),
Flag(
"drop_control_dependency",
parsed_flags.drop_control_dependency.bind(),
parsed_flags.drop_control_dependency.default_value(),
"If true, ignore control dependency requirements in input TensorFlow "
"GraphDef. Otherwise an error will be raised upon control dependency "
"inputs."),
Flag("rnn_states", parsed_flags.rnn_states.bind(),
parsed_flags.rnn_states.default_value(), ""),
Flag("model_checks", parsed_flags.model_checks.bind(),
@ -316,7 +309,6 @@ void ReadModelFlagsFromCommandLineFlags(
} while (false)
READ_MODEL_FLAG(variable_batch);
READ_MODEL_FLAG(drop_control_dependency);
#undef READ_MODEL_FLAG

View File

@ -138,8 +138,4 @@ message ModelFlags {
optional int32 count_max = 3 [default = -1];
}
repeated ModelCheck model_checks = 14;
// If true, ignore control dependency requirements in input TensorFlow
// GraphDef. Otherwise an error will be raised upon control dependency inputs.
optional bool drop_control_dependency = 15;
}

View File

@ -103,6 +103,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.allow_custom_ops.default_value(),
"If true, allow TOCO to create TF Lite Custom operators for all the"
"unsupported Tensorflow ops."),
Flag(
"drop_control_dependency",
parsed_flags.drop_control_dependency.bind(),
parsed_flags.drop_control_dependency.default_value(),
"If true, ignore control dependency requirements in input TensorFlow "
"GraphDef. Otherwise an error will be raised upon control dependency "
"inputs."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@ -163,6 +170,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {

View File

@ -36,7 +36,7 @@ enum FileFormat {
// are not normally encoded in model files and in general may not be thought
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
// Next Id: 12
// Next Id: 13
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@ -128,4 +128,12 @@ message TocoFlags {
// If true, allow TOCO to create TF Lite Custom operators for all the
// unsupported Tensorflow ops.
optional bool allow_custom_ops = 10;
// Applies only to the case when the input format is TENSORFLOW_GRAPHDEF.
// If true, then control dependencies will be immediately dropped during
// import.
// If not set, the default behavior is as follows:
// - Default to false if the output format is TENSORFLOW_GRAPHDEF.
// - Default to true in all other cases.
optional bool drop_control_dependency = 12;
}
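As a reading aid for the new field: the effective value is resolved from the presence bit plus the output format. A minimal Python sketch of that rule (format names are plain strings here purely for illustration; the real resolution happens in Import() further down in this change):

def effective_drop_control_dependency(has_flag, flag_value, output_format):
    # Honor an explicit setting; otherwise drop control dependencies unless
    # the output format is TENSORFLOW_GRAPHDEF.
    if has_flag:
        return flag_value
    return output_format != "TENSORFLOW_GRAPHDEF"

print(effective_drop_control_dependency(False, None, "TFLITE"))               # True
print(effective_drop_control_dependency(False, None, "TENSORFLOW_GRAPHDEF"))  # False
print(effective_drop_control_dependency(True, False, "TFLITE"))               # False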

View File

@ -85,38 +85,57 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new MakeInitialDequantizeOperator);
}
void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
const bool output_supports_only_float =
toco_flags.output_format() == TENSORFLOW_GRAPHDEF;
bool SupportsQuantization(FileFormat format) {
return (format == GRAPHVIZ_DOT || format == TFLITE);
}
ArrayDataType specified_final_data_type = ArrayDataType::kNone;
bool SupportsFusedActivationFunction(FileFormat format) {
return (format == GRAPHVIZ_DOT || format == TFLITE);
}
bool SupportsLstmCell(FileFormat format) {
return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT);
}
bool SupportsPreallocatedWorkspace(FileFormat format) {
return (format == GRAPHVIZ_DOT || format == TFLITE);
}
bool IsRealValued(toco::ArrayDataType type) {
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
type == toco::ArrayDataType::kUint8);
}
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
ArrayDataType type;
if (toco_flags.has_inference_input_type()) {
specified_final_data_type =
ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
} else if (toco_flags.has_inference_type()) {
specified_final_data_type =
ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
}
ArrayDataType final_data_type = ArrayDataType::kNone;
if (output_supports_only_float) {
QCHECK(specified_final_data_type == ArrayDataType::kNone ||
specified_final_data_type == ArrayDataType::kFloat);
final_data_type = ArrayDataType::kFloat;
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
} else if (!SupportsQuantization(output_format)) {
// Data type is implicitly float for non-quantized formats
type = ArrayDataType::kFloat;
} else {
final_data_type = specified_final_data_type;
// Nothing to do. Data types stay as-is.
return;
}
for (int i = 0; i < model->flags.input_arrays_size(); i++) {
auto* array = model->arrays[model->flags.input_arrays(i).name()].get();
string const& array_name = model->flags.input_arrays(i).name();
auto* array = model->arrays[array_name].get();
// Note that the notion of changing data types only applies to real-numbers
// arrays (see the documentation for inference_input_type).
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
// i.e. represent real numbers by means of quantization parameters,
// and not plain integer uint8 input arrays.
const bool is_real_numbers = array->data_type == ArrayDataType::kFloat ||
array->data_type == ArrayDataType::kUint8;
if (is_real_numbers) {
array->final_data_type = final_data_type;
if (!IsRealValued(array->data_type)) {
// Ignore non-real data types.
continue;
}
array->final_data_type = type;
}
}
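The new SetFinalDataTypeOnInputs boils down to a small priority rule. A hedged Python sketch of the same decision (dictionary stand-ins for Model/Array, enum values as strings; an illustration, not the toco implementation):

def set_final_type_on_inputs(input_arrays, inference_input_type, inference_type,
                             output_format):
    # Priority: explicit inference_input_type, then inference_type, then float
    # for output formats that cannot carry quantization; otherwise leave the
    # inputs untouched. Only real-valued (float / quantized uint8) inputs are
    # affected, mirroring IsRealValued() above.
    if inference_input_type is not None:
        final_type = inference_input_type
    elif inference_type is not None:
        final_type = inference_type
    elif output_format not in ("GRAPHVIZ_DOT", "TFLITE"):  # !SupportsQuantization
        final_type = "kFloat"
    else:
        return
    for array in input_arrays:
        if array["data_type"] in ("kFloat", "kUint8"):
            array["final_data_type"] = final_type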
@ -127,9 +146,16 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
const string& input_file_contents) {
std::unique_ptr<Model> model;
switch (toco_flags.input_format()) {
case TENSORFLOW_GRAPHDEF:
model = ImportTensorFlowGraphDef(model_flags, input_file_contents);
case TENSORFLOW_GRAPHDEF: {
TensorFlowImportFlags tf_import_flags;
tf_import_flags.drop_control_dependency =
toco_flags.has_drop_control_dependency()
? toco_flags.drop_control_dependency()
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
break;
}
case TFLITE:
model = toco::tflite::Import(model_flags, input_file_contents);
ResolveModelFlags(model_flags, model.get());
@ -148,23 +174,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
const bool output_is_tflite = output_format == TFLITE;
const bool quantize_output =
SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;
const bool output_is_tflite_quantized =
output_is_tflite && inference_type == QUANTIZED_UINT8;
if (output_is_tflite_quantized) {
if (quantize_output) {
QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
<< "Quantized inference is not allowed with float inputs.";
}
SetArrayFinalDataTypes(toco_flags, model);
SetFinalDataTypeOnInputs(toco_flags, model);
GraphTransformationsSet transformations;
MakeGeneralGraphTransformationsSet(&transformations);
auto* remove_trivial_reshape = new RemoveTrivialReshape;
transformations.Add(remove_trivial_reshape);
if (output_format == TFLITE) {
if (SupportsFusedActivationFunction(output_format)) {
transformations.Add(new FuseActivationFunctions);
} else {
transformations.Add(new UnfuseActivationFunctions);
@ -183,25 +207,24 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
// easy to pass a new toco flag. Once that is resolved on the DarwiNN
// tests side, the special-casing of DarwiNN here can go away.
// TODO(benoitjacob): so drop it when we can.
if ((output_is_tflite_quantized &&
toco_flags.reorder_across_fake_quant())) {
if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
transformations.Add(new DropFakeQuant);
}
}
transformations.Add(new ConvertPureConvToDepthwise);
// TFLite export does not yet support fused LSTM cell.
if (output_format == TENSORFLOW_GRAPHDEF) {
if (SupportsLstmCell(output_format)) {
transformations.Add(new IdentifyLstmCell);
}
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
if (output_is_tflite_quantized) {
if (quantize_output) {
RunGraphTransformations(model, "pre-quantization graph transformations",
{new HardcodeMinMax, new DropFakeQuant});
}
if (output_is_tflite_quantized) {
if (quantize_output) {
if (toco_flags.has_default_ranges_min() &&
toco_flags.has_default_ranges_max()) {
UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
@ -232,7 +255,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
CheckUnsupportedOperations(*model);
}
if (output_is_tflite) {
if (SupportsPreallocatedWorkspace(output_format)) {
AllocateTransientArrays(model, kDefaultTransientDataAlignment);
LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
}
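Taken together, the Supports*() helpers introduced in this file act as a capability table keyed by output format. A small Python rendering of that table plus the quantize_output condition used in Transform() (assumes the three formats listed are the only ones relevant here):

CAPABILITIES = {
    #                      quantization  fused_act  lstm_cell  prealloc_workspace
    "TENSORFLOW_GRAPHDEF": (False,        False,     True,      False),
    "GRAPHVIZ_DOT":        (True,         True,      True,      True),
    "TFLITE":              (True,         True,      False,     True),
}

def quantize_output(output_format, inference_type):
    # Quantized output requires both a format that supports quantization and
    # inference_type == QUANTIZED_UINT8, as in Transform() above.
    supports_quantization = CAPABILITIES[output_format][0]
    return supports_quantization and inference_type == "QUANTIZED_UINT8"

print(quantize_output("TFLITE", "QUANTIZED_UINT8"))               # True
print(quantize_output("TENSORFLOW_GRAPHDEF", "QUANTIZED_UINT8"))  # False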

View File

@ -294,6 +294,7 @@ void LogArray(int log_level, const Model& model, const string& name) {
VLOG(log_level) << "Array: " << name;
switch (array.data_type) {
case ArrayDataType::kNone:
VLOG(log_level) << " Data type:";
break;
case ArrayDataType::kFloat:
VLOG(log_level) << " Data type: kFloat";
@ -309,6 +310,24 @@ void LogArray(int log_level, const Model& model, const string& name) {
<< static_cast<int>(array.data_type) << ")";
break;
}
switch (array.final_data_type) {
case ArrayDataType::kNone:
VLOG(log_level) << " Final type:";
break;
case ArrayDataType::kFloat:
VLOG(log_level) << " Final type: kFloat";
break;
case ArrayDataType::kInt32:
VLOG(log_level) << " Final type: kInt32";
break;
case ArrayDataType::kUint8:
VLOG(log_level) << " Final type: kUint8";
break;
default:
VLOG(log_level) << " Final type: other (numerical value: "
<< static_cast<int>(array.final_data_type) << ")";
break;
}
if (array.buffer) {
VLOG(log_level) << " Constant Buffer";
}
@ -1016,7 +1035,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
}
RESOLVE_MODEL_FLAG(variable_batch)
RESOLVE_MODEL_FLAG(drop_control_dependency)
#undef RESOLVE_MODEL_FLAG
@ -1044,12 +1062,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
"--output_arrays flag must be given on the command-line.";
for (const auto& input_array_proto : model->flags.input_arrays()) {
QCHECK(!input_array_proto.shape().empty())
<< "This model does not have shape defined for input array "
<< input_array_proto.name()
<< ", so one must be specified by a non-empty --input_shape "
"command-line flag.";
auto& input_array = model->GetOrCreateArray(input_array_proto.name());
if (input_array_proto.has_data_type()) {
const ArrayDataType specified_type =
@ -1072,6 +1084,14 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
input_array.data_type = ArrayDataType::kFloat;
}
if (!input_array.has_shape()) {
QCHECK(!input_array_proto.shape().empty())
<< "This model does not have shape defined for input array "
<< input_array_proto.name()
<< ", so one must be specified by a non-empty --input_shape "
"command-line flag.";
}
// Compare/merge the model->flags describing the input_shape with
// the actual input array's shape.
auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
@ -1563,7 +1583,11 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
for (const auto& array_entry : model.arrays) {
const auto& array = *array_entry.second;
if (array.final_data_type != ArrayDataType::kNone) {
CHECK(array.final_data_type == array.data_type);
CHECK(array.final_data_type == array.data_type)
<< "Array \"" << array_entry.first
<< "\" has mis-matching actual and final data types ("
<< static_cast<int>(array.data_type) << ","
<< static_cast<int>(array.final_data_type) << ").";
}
}
}

View File

@ -1,4 +1,4 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,7 +26,6 @@ from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
# pylint: enable=wildcard-import
@ -35,12 +34,18 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'PowerSignOptimizer', 'AddSignOptimizer'
'PowerSignOptimizer',
'AddSignOptimizer'
'DelayCompensatedGradientDescentOptimizer',
'DropStaleGradientOptimizer', 'ExternalOptimizerInterface',
'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer',
'ScipyOptimizerInterface', 'VariableClippingOptimizer',
'MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm',
'DropStaleGradientOptimizer',
'ExternalOptimizerInterface',
'LazyAdamOptimizer',
'NadamOptimizer',
'MovingAverageOptimizer',
'ScipyOptimizerInterface',
'VariableClippingOptimizer',
'MultitaskOptimizerWrapper',
'clip_gradients_by_global_norm',
]
remove_undocumented(__name__, _allowed_symbols)
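One detail worth noting about the list above (both before and after the reformat): there is no comma after 'AddSignOptimizer', so Python's implicit string-literal concatenation joins it with the next entry. A minimal illustration:

symbols = [
    'PowerSignOptimizer',
    'AddSignOptimizer'   # missing comma: merges with the next literal
    'DelayCompensatedGradientDescentOptimizer',
]
print(len(symbols))   # 2, not 3
print(symbols[1])     # AddSignOptimizerDelayCompensatedGradientDescentOptimizer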

View File

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An optimizer wrapper that ensures correct behaviour
of stateful optimizers with multitask loss."""
"""An optimizer wrapper for stateful optimizers with multitask loss."""
from __future__ import absolute_import
from __future__ import division
@ -30,26 +28,27 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
__all__ = ["MultitaskOptimizerWrapper",
"clip_gradients_by_global_norm"]
__all__ = ['MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm']
def _is_all_zeros(grad):
all_zeros = math_ops.equal(math_ops.count_nonzero(grad), 0)
return all_zeros
def _get_wrapper(fn, opt):
def wrapper(self, grad, *args, **kwargs): # pylint: disable=unused-argument
all_zeros = _is_all_zeros(grad)
return control_flow_ops.cond(
all_zeros,
control_flow_ops.no_op,
lambda: fn(grad, *args, **kwargs))
return control_flow_ops.cond(all_zeros, control_flow_ops.no_op,
lambda: fn(grad, *args, **kwargs))
wrapper = types.MethodType(wrapper, opt)
return wrapper
class MultitaskOptimizerWrapper(object):
"""Optimizer wrapper that ensures that
all-zero gradients don't affect the optimizer state.
"""Optimizer wrapper making all-zero gradients harmless.
This might be useful when a multi-task loss is used,
and some components of the loss might be
@ -88,20 +87,20 @@ class MultitaskOptimizerWrapper(object):
gradvars_clipped, global_step=batch)
```
"""
def __init__(self, opt):
"""
"""Constructor.
Args:
opt: an instance of a class that implements tf.train.Optimizer.
opt: an instance of a class that implements tf.train.Optimizer.
"""
if not isinstance(opt, optimizer.Optimizer):
raise TypeError(
"Supplied optimizer must be an instance of tf.train.Optimizer")
'Supplied optimizer must be an instance of tf.train.Optimizer')
self._opt = opt
overriden_methods = ('_apply_dense',
'_resource_apply_dense',
'_apply_sparse',
'_resource_apply_sparse')
for name in overriden_methods:
overridden_methods = ('_apply_dense', '_resource_apply_dense',
'_apply_sparse', '_resource_apply_sparse')
for name in overridden_methods:
fn = getattr(self._opt, name)
wrapper = _get_wrapper(fn, self._opt)
setattr(self._opt, name, wrapper)
@ -112,27 +111,30 @@ class MultitaskOptimizerWrapper(object):
def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
"""Clips gradients of a multitask loss by their global norm.
Ignores all-zero tensors when computing the global norm.
Args:
gradients_variables: a list of pairs (gradient, variable).
clip_norm: a float Tensor, the global norm to clip on. Default is 20.0.
gradients_variables: a list of pairs (gradient, variable).
clip_norm: a float Tensor, the global norm to clip on. Default is 20.0.
Returns:
list: A list of pairs of the same type as gradients_variables,.
fixed_global_norm: A 0-D (scalar) Tensor representing the global norm.
list: A list of pairs of the same type as gradients_variables.
fixed_global_norm: A 0-D (scalar) Tensor representing the global norm.
"""
gradients, variables = six.moves.zip(*gradients_variables)
def _replace_nonexisting_grad(grad):
if grad is None:
return grad
all_zeros = _is_all_zeros(grad)
return control_flow_ops.cond(all_zeros,
lambda: array_ops.zeros(
[], dtype=dtypes.as_dtype(grad.dtype)),
lambda: grad)
return control_flow_ops.cond(
all_zeros,
lambda: array_ops.zeros([], dtype=dtypes.as_dtype(grad.dtype)),
lambda: grad)
nonzero_gradients = [_replace_nonexisting_grad(g) for g in gradients]
fixed_global_norm = clip_ops.global_norm(nonzero_gradients)
gradients, _ = clip_ops.clip_by_global_norm(gradients, clip_norm,
use_norm=fixed_global_norm)
gradients, _ = clip_ops.clip_by_global_norm(
gradients, clip_norm, use_norm=fixed_global_norm)
return list(six.moves.zip(gradients, variables)), fixed_global_norm
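For context, a minimal end-to-end usage sketch of the two public names in this module, modeled on the docstring and the accompanying test (assumes the TensorFlow 1.x contrib APIs used throughout this file; a sketch, not an official recipe):

import tensorflow as tf
from tensorflow.contrib.opt.python.training import multitask_optimizer_wrapper

var0 = tf.Variable([1.0, 2.0], dtype=tf.float32)
var1 = tf.Variable([3.0, 4.0], dtype=tf.float32)
grads0 = tf.constant([0.1, 0.1], dtype=tf.float32)
grads1 = tf.constant([0.0, 0.0], dtype=tf.float32)  # all-zero: task absent in this batch

opt = multitask_optimizer_wrapper.MultitaskOptimizerWrapper(
    tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9))
gradvars, global_norm = multitask_optimizer_wrapper.clip_gradients_by_global_norm(
    zip([grads0, grads1], [var0, var1]), clip_norm=1.0)
train_op = opt.apply_gradients(gradvars)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)  # var1 and its momentum slot are left untouched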

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib.opt.python.training import multitask_optimizer_wrapper
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -25,13 +28,11 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import momentum
import numpy as np
import six
class MultitaskOptimizerWrapperTest(test.TestCase):
"""Tests for the multitask optimizer wrapper.
"""
Tests for the multitask optimizer wrapper.
"""
def testWrapper(self):
with self.test_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
@ -39,12 +40,10 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtypes.float32)
grads_allzero = constant_op.constant([0.0, 0.0], dtype=dtypes.float32)
mom_opt_impl = momentum.MomentumOptimizer(
learning_rate=2.0, momentum=0.9)
mom_opt_impl = momentum.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
mom_opt = multitask_optimizer_wrapper.MultitaskOptimizerWrapper(
mom_opt_impl)
mom_update = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
mom_update_partial = mom_opt.apply_gradients(
zip([grads_allzero, grads1], [var0, var1]))
mom_update_no_action = mom_opt.apply_gradients(
@ -63,14 +62,13 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
# Step 1: normal momentum update.
self.evaluate(mom_update)
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
self.evaluate(slot1))
self.assertAllCloseAccordingToType(
np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([0.01, 0.01]), self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
self.evaluate(var0))
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
self.evaluate(var1))
@ -78,8 +76,8 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
# Step 2: momentum update that changes only slot1 but not slot0.
self.evaluate(mom_update_partial)
# Check that only the relevant momentum accumulator has been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1))
@ -87,8 +85,8 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
# Step 3: momentum update that does not change anything.
self.evaluate(mom_update_no_action)
# Check that the momentum accumulators have *NOT* been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1))
@ -105,8 +103,9 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
grads3 = None
varlist = [var0, var1, var2, var3]
gradients = [grads0, grads1, grads2, grads3]
clipped_gradvars, global_norm = multitask_optimizer_wrapper.clip_gradients_by_global_norm(
six.moves.zip(gradients, varlist), clip_norm=1.0)
clipped_gradvars, global_norm = (
multitask_optimizer_wrapper.clip_gradients_by_global_norm(
six.moves.zip(gradients, varlist), clip_norm=1.0))
clipped_grads = list(six.moves.zip(*clipped_gradvars))[0]
reference_global_norm = np.sqrt(np.sum(np.square([10.0, 15.0, 0.0, 5.0])))
self.assertAllCloseAccordingToType(
@ -115,5 +114,6 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
self.evaluate(clipped_grads[2]), np.array([0., 0.]))
self.assertEqual(clipped_grads[3], None)
if __name__ == "__main__":
test.main()

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -374,19 +375,20 @@ class RNNCellTest(test.TestCase):
h = array_ops.zeros([batch_size, num_proj])
state = rnn_cell_impl.LSTMStateTuple(c, h)
cell = contrib_rnn_cell.LayerNormLSTMCell(
num_units=num_units,
num_proj=num_proj,
forget_bias=1.0,
layer_norm=True,
norm_gain=1.0,
norm_shift=0.0)
num_units=num_units,
num_proj=num_proj,
forget_bias=1.0,
layer_norm=True,
norm_gain=1.0,
norm_shift=0.0)
g, out_m = cell(x, state)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, out_m], {
x.name: np.ones((batch_size, input_size)),
c.name: 0.1 * np.ones((batch_size, num_units)),
h.name: 0.1 * np.ones((batch_size, num_proj))
})
res = sess.run(
[g, out_m], {
x.name: np.ones((batch_size, input_size)),
c.name: 0.1 * np.ones((batch_size, num_units)),
h.name: 0.1 * np.ones((batch_size, num_proj))
})
self.assertEqual(len(res), 2)
# The numbers in results were not calculated, this is mostly just a
# smoke test.
@ -396,9 +398,9 @@ class RNNCellTest(test.TestCase):
# Different inputs so different outputs and states
for i in range(1, batch_size):
self.assertTrue(
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6)
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6)
self.assertTrue(
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
def testOutputProjectionWrapper(self):
with self.test_session() as sess:

View File

@ -996,26 +996,19 @@ class RNNCellTest(test.TestCase):
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
hidden[0].name:
np.array([[[[[1.],[1.]],
[[1.],[1.]]],
[[[1.],[1.]],
[[1.],[1.]]]],
[[[[2.],[2.]],
[[2.],[2.]]],
[[[2.],[2.]],
[[2.],[2.]]]]]),
x.name:
np.array([[[[[1.],[1.]],
[[1.],[1.]]],
[[[1.],[1.]],
[[1.],[1.]]]],
[[[[2.],[2.]],
[[2.],[2.]]],
[[[2.],[2.]],
[[2.],[2.]]]]])
})
res = sess.run(
[output, state], {
hidden[0].name:
np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
1.
], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]],
[[[2.], [2.]], [[2.], [2.]]]]]),
x.name:
np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
1.
], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], [[[2.], [2.]],
[[2.], [2.]]]]])
})
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@ -1276,10 +1269,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[2].c, expected_c1, 1e-5)
self.assertAllClose(res[2].h, expected_h1, 1e-5)
def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell
should be same. """
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@ -1290,21 +1281,21 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
c1 = array_ops.zeros([1, 2])
h1 = array_ops.zeros([1, 2])
state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
cell = rnn_cell_impl.MultiRNNCell(
[contrib_rnn_cell.LayerNormLSTMCell(
2,
layer_norm=True,
norm_gain=1.0,
norm_shift=0.0) for _ in range(2)])
cell = rnn_cell_impl.MultiRNNCell([
contrib_rnn_cell.LayerNormLSTMCell(
2, layer_norm=True, norm_gain=1.0, norm_shift=0.0)
for _ in range(2)
])
h, (s0, s1) = cell(x, (state0, state1))
sess.run([variables.global_variables_initializer()])
res = sess.run([h, s0, s1], {
x.name: np.array([[1., 1.]]),
c0.name: 0.1 * np.asarray([[0, 1]]),
h0.name: 0.1 * np.asarray([[2, 3]]),
c1.name: 0.1 * np.asarray([[4, 5]]),
h1.name: 0.1 * np.asarray([[6, 7]]),
})
res = sess.run(
[h, s0, s1], {
x.name: np.array([[1., 1.]]),
c0.name: 0.1 * np.asarray([[0, 1]]),
h0.name: 0.1 * np.asarray([[2, 3]]),
c1.name: 0.1 * np.asarray([[4, 5]]),
h1.name: 0.1 * np.asarray([[6, 7]]),
})
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_h0 = np.array([[-0.38079708, 0.38079708]])

View File

@ -115,7 +115,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The class uses optional peep-hole connections, and an optional projection
layer.
Layer normalization implementation is based on:
https://arxiv.org/abs/1607.06450.
@ -124,15 +123,24 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
and is applied before the internal nonlinearities.
"""
def __init__(self, num_units, use_peepholes=False,
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=1, num_proj_shards=1,
forget_bias=1.0, state_is_tuple=True,
activation=math_ops.tanh, reuse=None,
layer_norm=False, norm_gain=1.0, norm_shift=0.0):
def __init__(self,
num_units,
use_peepholes=False,
initializer=None,
num_proj=None,
proj_clip=None,
num_unit_shards=1,
num_proj_shards=1,
forget_bias=1.0,
state_is_tuple=True,
activation=math_ops.tanh,
reuse=None,
layer_norm=False,
norm_gain=1.0,
norm_shift=0.0):
"""Initialize the parameters for an LSTM cell.
Args:
@ -164,8 +172,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
`layer_norm` has been set to `False`, this argument will be ignored.
norm_shift: float, The layer normalization shift initial value. If
`layer_norm` has been set to `False`, this argument will be ignored.
"""
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
@ -2049,8 +2055,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
if self._skip_connection:
self._total_output_channels += self._input_shape[-1]
state_size = tensor_shape.TensorShape(self._input_shape[:-1]
+ [self._output_channels])
state_size = tensor_shape.TensorShape(
self._input_shape[:-1] + [self._output_channels])
self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
+ [self._total_output_channels])
@ -2110,11 +2116,8 @@ class Conv3DLSTMCell(ConvLSTMCell):
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
def _conv(args,
filter_size,
num_features,
bias,
bias_start=0.0):
def _conv(args, filter_size, num_features, bias, bias_start=0.0):
"""convolution:
Args:
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
@ -2391,12 +2394,19 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
"""
def __init__(self, num_units,
use_peepholes=False, cell_clip=None,
initializer=None, num_proj=None, proj_clip=None,
def __init__(self,
num_units,
use_peepholes=False,
cell_clip=None,
initializer=None,
num_proj=None,
proj_clip=None,
forget_bias=1.0,
activation=None, layer_norm=False,
norm_gain=1.0, norm_shift=0.0, reuse=None):
activation=None,
layer_norm=False,
norm_gain=1.0,
norm_shift=0.0,
reuse=None):
"""Initialize the parameters for an LSTM cell.
Args:
@ -2457,7 +2467,6 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
def output_size(self):
return self._output_size
def _linear(self,
args,
output_size,
@ -2507,9 +2516,9 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
scope = vs.get_variable_scope()
with vs.variable_scope(scope) as outer_scope:
weights = vs.get_variable(
"kernel", [total_arg_size, output_size],
dtype=dtype,
initializer=kernel_initializer)
"kernel", [total_arg_size, output_size],
dtype=dtype,
initializer=kernel_initializer)
if len(args) == 1:
res = math_ops.matmul(args[0], weights)
else:
@ -2521,9 +2530,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
if bias_initializer is None:
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
biases = vs.get_variable(
"bias", [output_size],
dtype=dtype,
initializer=bias_initializer)
"bias", [output_size], dtype=dtype, initializer=bias_initializer)
if not layer_norm:
res = nn_ops.bias_add(res, biases)
@ -2554,7 +2561,6 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
ValueError: If input size cannot be inferred from inputs via
static shape inference.
"""
num_proj = self._num_units if self._num_proj is None else self._num_proj
sigmoid = math_ops.sigmoid
(c_prev, m_prev) = state
@ -2567,10 +2573,14 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = self._linear([inputs, m_prev], 4 * self._num_units, bias=True,
bias_initializer=None, layer_norm=self._layer_norm)
lstm_matrix = self._linear(
[inputs, m_prev],
4 * self._num_units,
bias=True,
bias_initializer=None,
layer_norm=self._layer_norm)
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
value=lstm_matrix, num_or_size_splits=4, axis=1)
if self._layer_norm:
i = _norm(self._norm_gain, self._norm_shift, i, "input")
@ -2580,20 +2590,22 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
# Diagonal connections
if self._use_peepholes:
with vs.variable_scope(unit_scope) as projection_scope:
with vs.variable_scope(unit_scope):
w_f_diag = vs.get_variable(
"w_f_diag", shape=[self._num_units], dtype=dtype)
"w_f_diag", shape=[self._num_units], dtype=dtype)
w_i_diag = vs.get_variable(
"w_i_diag", shape=[self._num_units], dtype=dtype)
"w_i_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = vs.get_variable(
"w_o_diag", shape=[self._num_units], dtype=dtype)
"w_o_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
sigmoid(i + w_i_diag * c_prev) * self._activation(j))
c = (
sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
sigmoid(i + w_i_diag * c_prev) * self._activation(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
self._activation(j))
c = (
sigmoid(f + self._forget_bias) * c_prev +
sigmoid(i) * self._activation(j))
if self._layer_norm:
c = _norm(self._norm_gain, self._norm_shift, c, "state")
@ -2608,7 +2620,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
with vs.variable_scope("projection") as proj_scope:
with vs.variable_scope("projection"):
m = self._linear(m, self._num_proj, bias=False)
if self._proj_clip is not None:

View File

@ -192,7 +192,8 @@ class _BaseAttentionMechanism(AttentionMechanism):
raise TypeError("probability_fn must be callable, saw type: %s" %
type(probability_fn).__name__)
if score_mask_value is None:
score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf)
score_mask_value = dtypes.as_dtype(
self._memory_layer.dtype).as_numpy_dtype(-np.inf)
self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda
probability_fn(
_maybe_mask_score(score, memory_sequence_length, score_mask_value),
@ -1145,7 +1146,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
% (len(attention_layer_sizes), len(attention_mechanisms)))
self._attention_layers = tuple(
layers_core.Dense(
attention_layer_size, name="attention_layer", use_bias=False,
attention_layer_size,
name="attention_layer",
use_bias=False,
dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
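The score_mask_value default above (negative infinity in the memory layer's dtype) exists so that, after the softmax inside probability_fn, positions beyond memory_sequence_length receive essentially zero attention weight. A small numpy sketch of the idea (illustration only, not the library's _maybe_mask_score):

import numpy as np

def mask_scores(scores, lengths, mask_value=-np.inf):
    # scores: [batch, max_time]; positions at index >= length get mask_value.
    time = np.arange(scores.shape[1])
    return np.where(time[None, :] < np.asarray(lengths)[:, None], scores, mask_value)

scores = np.array([[1.0, 2.0, 3.0]])
masked = mask_scores(scores, lengths=[2])        # [[1., 2., -inf]]
weights = np.exp(masked) / np.exp(masked).sum()  # softmax: last weight is exactly 0
print(weights)                                   # ~[[0.27, 0.73, 0.]]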

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma.h"
#include <fcntl.h>
#include <cstdlib>
#include <fcntl.h>
#include "tensorflow/contrib/verbs/verbs_util.h"
@ -137,7 +138,7 @@ ibv_device* set_device() {
if (!env_p_rdma_device.empty()) {
for (device_index = 0; device_index < dev_num; device_index++) {
if (!env_p_rdma_device.compare(
ibv_get_device_name(dev_list[device_index]))) {
ibv_get_device_name(dev_list[device_index]))) {
CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
<< "Device " << ibv_get_device_name(dev_list[device_index])
<< " has no active ports";
@ -147,7 +148,7 @@ ibv_device* set_device() {
// check validity of input device
CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
} else {
// set default device
// set default device
str_port_num = get_env_var("RDMA_DEVICE_PORT");
CHECK(str_port_num.empty())
<< "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
@ -177,7 +178,7 @@ ibv_device* set_device() {
// Returns:
// port to use
uint8_t set_port(ibv_context* context) {
uint8_t port_num = 0; //0 is illegal port number
uint8_t port_num = 0; // 0 is illegal port number
string str_port_num;
ibv_device_attr device_att;
ibv_port_attr port_attr;
@ -199,9 +200,7 @@ uint8_t set_port(ibv_context* context) {
// check if port id active
CHECK(port_attr.state == IBV_PORT_ACTIVE)
<< "Selected RDMA_DEVICE_PORT is not active";
}
// set default port
else {
} else { // set default port
for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
rc = ibv_query_port(context, port_index, &port_attr);
CHECK(!rc) << "Failed to query the port" << port_index;
@ -269,7 +268,7 @@ bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
// Function to set GID index.
// If the port link is IB, no GID index should be selected.
// If Ethernet but RDMA_GID_INDEX not set gid index that supports
// RoCE V2 will be chosen(fails if more then one IP is configured)
// RoCE V2 will be chosen(fails if more than one IP is configured)
// Args:
// context - device context
// port_num - port number
@ -302,7 +301,7 @@ uint8_t set_gid(uint8_t port_num, ibv_context* context) {
}
}
switch (port_attr.link_layer) {
case(IBV_LINK_LAYER_ETHERNET) :
case (IBV_LINK_LAYER_ETHERNET):
gid_str = get_env_var("RDMA_GID_INDEX");
if (!gid_str.empty()) {
gid_index = stoi(gid_str);
@ -313,7 +312,7 @@ uint8_t set_gid(uint8_t port_num, ibv_context* context) {
<< "More than one IP is available, please specify GID_INDEX";
}
break;
case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index
case (IBV_LINK_LAYER_INFINIBAND): // no need in GID index
break;
default:
LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
@ -374,7 +373,8 @@ enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
break;
default:
CHECK(0) << "Error: MTU input value must be one of the following: 256, "
"512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n";
"512, 1024, 2048, 4096. MTU "
<< mtu << " is invalid\n";
break;
}
CHECK(mtu < port_attr.active_mtu)

View File

@ -921,7 +921,7 @@ Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
if (first_value == 0) {
*out = MakeDim(second);
} else if (second_value == 0) {
*out = MakeDim(first);
*out = first;
} else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
} else {
@ -946,7 +946,7 @@ Status InferenceContext::Subtract(DimensionHandle first,
const int64 second_value = Value(second);
// Special cases.
if (second_value == 0) {
*out = MakeDim(first);
*out = first;
} else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
} else {

View File

@ -455,7 +455,6 @@ class Graph {
// the corresponding NodeDef to reflect the change.
// REQUIRES: The control edge must exist.
void RemoveControlEdge(const Edge* e);
// Updates the input to a node. The existing edge to `dst` is removed and an
// edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
// is also updated.

View File

@ -118,11 +118,9 @@ class GraphTest : public ::testing::Test {
LOG(FATAL) << name;
}
bool ControlEdgeExistsInGraphOrNodeDef(const Node* src,
const Node* dst) {
for (const Edge *e : dst->in_edges()) {
if (e->IsControlEdge() &&
e->src() == src &&
bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
for (const Edge* e : dst->in_edges()) {
if (e->IsControlEdge() && e->src() == src &&
e->src_output() == Graph::kControlSlot &&
e->dst_input() == Graph::kControlSlot) {
return true;

View File

@ -702,12 +702,16 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources) const {
resources,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algorithm should converge in at most
// num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
// The same applies to resources.
const int64 num_loops = new_shapes->size();
VLOG(1) << "Propagating (relax=" << relax << ") " << new_shapes->size()
<< " new shapes through " << num_loops << " loops and "
<< resources.size() << " resources" << std::endl;
const int64 max_loop_length = item_.graph.node_size();
const int64 max_rank = 4;
const int64 max_loop_iterations =
@ -721,9 +725,12 @@ Status GraphProperties::PropagateShapes(
while (!new_shapes->empty() &&
num_loop_iterations++ < max_loop_iterations) {
const Node* n = new_shapes->pop();
for (const Node* fanout : n->out_nodes()) {
TF_RETURN_IF_ERROR(
UpdateShapes(shape_refiner, relax, fanout, new_shapes));
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
const Node* fanout = e->dst();
TF_RETURN_IF_ERROR(
UpdateShapes(shape_refiner, relax, fanout, new_shapes));
}
}
}
@ -818,6 +825,7 @@ Status GraphProperties::InferStatically() {
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
std::unordered_set<const Node*> enter_nodes;
std::unordered_set<const Node*> merge_nodes;
int num_loops = 0;
for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) {
if (node->input_type(i) == DataType::DT_RESOURCE) {
@ -830,6 +838,8 @@ Status GraphProperties::InferStatically() {
enter_nodes.insert(node);
} else if (node->IsMerge()) {
merge_nodes.insert(node);
} else if (node->IsNextIteration()) {
++num_loops;
}
}
@ -853,7 +863,7 @@ Status GraphProperties::InferStatically() {
}
// Propagate shapes normally.
TF_RETURN_IF_ERROR(
PropagateShapes(&refiner, relax, &new_shapes, resources));
PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
}
// Track shapes globally across the graph.
@ -906,6 +916,9 @@ Status GraphProperties::InferStatically() {
&input_properties[i]);
}
for (const auto& edge : node->in_edges()) {
if (edge->IsControlEdge()) {
continue;
}
if (!edge->src()->IsConstant()) {
continue;
}
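The comment at the top of PropagateShapes bounds the work by the loop structure of the graph; num_loops now counts NextIteration nodes in InferStatically and is threaded through as a parameter. A small sketch of the budget the comment describes (the actual C++ cap also folds in the graph's node count, whose full expression is not shown in this excerpt, so treat this purely as an illustration of the stated bound):

def propagation_budget(num_loops, max_rank=4):
    # "The algorithm should converge in at most num_nested_loops^2 * max_rank";
    # max_rank is approximated by the constant 4.
    return max(1, num_loops * num_loops) * max_rank

for n in (0, 1, 3):
    print(n, propagation_budget(n))   # 0 -> 4, 1 -> 4, 3 -> 36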

View File

@ -108,7 +108,8 @@ class GraphProperties {
Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources) const;
resources,
int num_loops) const;
};
} // end namespace grappler

View File

@ -24,64 +24,40 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsAdd(const NodeDef& node) {
const auto op = node.op();
return op == "Add";
}
bool IsAdd(const NodeDef& node) { return node.op() == "Add"; }
bool IsAddN(const NodeDef& node) {
const auto op = node.op();
return op == "AddN";
}
bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
bool IsAvgPoolGrad(const NodeDef& node) {
const auto op = node.op();
return op == "AvgPoolGrad";
}
bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
bool IsBiasAddGrad(const NodeDef& node) {
const auto op = node.op();
return op == "BiasAddGrad";
}
bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
bool IsConcatOffset(const NodeDef& node) {
const auto op = node.op();
return op == "ConcatOffset";
}
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
bool IsConstant(const NodeDef& node) {
const auto op = node.op();
return op == "Const";
}
bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
bool IsConv2D(const NodeDef& node) {
const auto op = node.op();
return op == "Conv2D";
}
bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
const auto op = node.op();
return op == "Conv2DBackpropFilter";
return node.op() == "Conv2DBackpropFilter";
}
bool IsConv2DBackpropInput(const NodeDef& node) {
const auto op = node.op();
return op == "Conv2DBackpropInput";
return node.op() == "Conv2DBackpropInput";
}
bool IsDepthwiseConv2dNative(const NodeDef& node) {
const auto op = node.op();
return op == "DepthwiseConv2dNative";
return node.op() == "DepthwiseConv2dNative";
}
bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
const auto op = node.op();
return op == "DepthwiseConv2dNativeBackpropFilter";
return node.op() == "DepthwiseConv2dNativeBackpropFilter";
}
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
const auto op = node.op();
return op == "DepthwiseConv2dNativeBackpropInput";
return node.op() == "DepthwiseConv2dNativeBackpropInput";
}
bool IsDequeueOp(const NodeDef& node) {
@ -101,14 +77,10 @@ bool IsExit(const NodeDef& node) {
return op == "Exit" || op == "RefExit";
}
bool IsFloorMod(const NodeDef& node) {
const auto& op = node.op();
return op == "FloorMod";
}
bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
bool IsFusedBatchNormGradV1(const NodeDef& node) {
const auto& op = node.op();
return op == "FusedBatchNormGrad";
return node.op() == "FusedBatchNormGrad";
}
bool IsIdentity(const NodeDef& node) {
@ -121,25 +93,16 @@ bool IsMerge(const NodeDef& node) {
return op == "Merge" || op == "RefMerge";
}
bool IsMul(const NodeDef& node) {
const auto op = node.op();
return op == "Mul";
}
bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
bool IsNoOp(const NodeDef& node) {
const auto op = node.op();
return op == "NoOp";
}
bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
bool IsNextIteration(const NodeDef& node) {
const auto& op = node.op();
return op == "NextIteration" || op == "RefNextIteration";
}
bool IsPad(const NodeDef& node) {
const auto op = node.op();
return op == "Pad";
}
bool IsPad(const NodeDef& node) { return node.op() == "Pad"; }
bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
@ -147,20 +110,11 @@ bool IsPlaceholder(const NodeDef& node) {
op == "PlaceholderWithDefault";
}
bool IsRealDiv(const NodeDef& node) {
const auto op = node.op();
return op == "RealDiv";
}
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
bool IsReluGrad(const NodeDef& node) {
const auto op = node.op();
return op == "ReluGrad";
}
bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
bool IsRecv(const NodeDef& node) {
const auto op = node.op();
return op == "_Recv";
}
bool IsRecv(const NodeDef& node) { return node.op() == "_Recv"; }
bool IsReduction(const NodeDef& node) {
const auto& op = node.op();
@ -175,53 +129,34 @@ bool IsRestore(const NodeDef& node) {
node.op() == "RestoreSlice");
}
bool IsSend(const NodeDef& node) {
const auto op = node.op();
return op == "_Send";
}
bool IsSend(const NodeDef& node) { return node.op() == "_Send"; }
bool IsSlice(const NodeDef& node) {
const auto op = node.op();
return op == "Slice";
}
bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
bool IsSquaredDifference(const NodeDef& node) {
const auto op = node.op();
return op == "SquaredDifference";
return node.op() == "SquaredDifference";
}
bool IsSqueeze(const NodeDef& node) {
const auto op = node.op();
return op == "Squeeze";
}
bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
bool IsStopGradient(const NodeDef& node) {
const auto& op = node.op();
return op == "StopGradient" || op == "PreventGradient";
}
bool IsSub(const NodeDef& node) {
const auto op = node.op();
return op == "Sub";
}
bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
bool IsSum(const NodeDef& node) {
const auto op = node.op();
return op == "Sum";
}
bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
bool IsSwitch(const NodeDef& node) {
const auto& op = node.op();
return op == "Switch" || op == "RefSwitch";
}
bool IsTranspose(const NodeDef& node) {
const auto op = node.op();
return op == "Transpose";
}
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
bool IsVariable(const NodeDef& node) {
const auto op = node.op();
const auto& op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
op == "VarHandleOp" || op == "ReadVariableOp";
}

View File

@ -25,6 +25,7 @@ namespace grappler {
bool IsAdd(const NodeDef& node);
bool IsAddN(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);

View File

@ -448,6 +448,10 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
if (node.device().find("SPU") != string::npos) {
return false;
}
// Workaround for Assert mistakenly being labeled as stateful.
if (IsAssert(node)) {
return true;
}
return IsFreeOfSideEffect(node);
}

View File

@ -81,6 +81,38 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
EXPECT_EQ("c1", new_mul.input(1));
}
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
Output mul = ops::Multiply(s.WithOpName("mul").WithControlDependencies(
{assert1.operation, assert2.operation}),
check1, check2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ArithmeticOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
// Run the optimizer twice to make sure the rewrite is idempotent.
item.graph.Swap(&output);
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(5, output.node_size());
const NodeDef& new_mul = output.node(3);
EXPECT_EQ(4, new_mul.input_size());
EXPECT_EQ("check1", new_mul.input(0));
EXPECT_EQ("check1", new_mul.input(1));
EXPECT_EQ("^assert1", new_mul.input(2));
EXPECT_EQ("^assert1", new_mul.input(3));
}
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});

View File

@ -1720,6 +1720,7 @@ tf_cuda_cc_tests(
":data_flow",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -97,8 +97,9 @@ class BincountOp : public OpKernel {
const Tensor& weights_t = ctx->input(2);
int32 size = size_tensor.scalar<int32>()();
OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument(
"size (", size, ") must be non-negative"));
OP_REQUIRES(
ctx, size >= 0,
errors::InvalidArgument("size (", size, ") must be non-negative"));
const auto arr = arr_t.flat<int32>();
const auto weights = weights_t.flat<T>();
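For reference, the operation being guarded here (a non-negative size, optional per-entry weights) follows the same shape of semantics as numpy's bincount; a short sketch for orientation (numpy used purely for illustration):

import numpy as np

arr = np.array([0, 1, 1, 3], dtype=np.int32)
weights = np.array([0.5, 1.0, 2.0, 4.0], dtype=np.float32)
size = 5  # must be non-negative, as the OP_REQUIRES above enforces

# Without weights each bin counts occurrences of its index in arr; with weights,
# bin i sums the weights of the entries equal to i. minlength pads out to `size`.
print(np.bincount(arr, minlength=size))                   # [1 2 0 1 0]
print(np.bincount(arr, weights=weights, minlength=size))  # [0.5 3. 0. 4. 0.]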

View File

@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_BINCOUNT_OP_H_
#define TENSORFLOW_BINCOUNT_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {

View File

@ -17,12 +17,12 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/bincount_op.h"
#include "external/cub_archive/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bincount_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
@ -93,8 +93,8 @@ struct BincountFunctor<GPUDevice, T> {
/* num_samples */ num_samples,
/* stream */ stream);
if (err != cudaSuccess) {
return errors::Internal("Could not launch HistogramEven: ",
cudaGetErrorString(err), ".");
return errors::Internal(
"Could not launch HistogramEven: ", cudaGetErrorString(err), ".");
}
return Status::OK();
}

View File

@ -30,8 +30,8 @@ static Graph* Bincount(int arr_size, int nbins) {
Tensor arr(DT_INT32, TensorShape({arr_size}));
arr.flat<int32>() = arr.flat<int32>().setRandom().abs();
Tensor size(DT_INT32, TensorShape({(int32)1}));
size.flat<int32>()(0) = (int32)nbins;
Tensor size(DT_INT32, TensorShape({static_cast<int32>(1)}));
size.flat<int32>()(0) = static_cast<int32>(nbins);
Tensor weights(DT_INT32, TensorShape({0}));

View File

@ -77,10 +77,10 @@ struct BucketizeFunctor<GPUDevice, T> {
TF_RETURN_IF_ERROR(boundaries_array.Finalize());
CudaLaunchConfig config = GetCudaLaunchConfig(input.size(), d);
BucketizeCustomKernel<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
input.size(), input.data(), boundaries_vector.size(),
boundaries_array.data(), output.data());
BucketizeCustomKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
input.size(), input.data(), boundaries_vector.size(),
boundaries_array.data(), output.data());
return Status::OK();
}

View File

@ -1101,29 +1101,27 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
bool cudnn_use_autotune_;
};
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
Conv3DBackpropInputOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("input_sizes"), \
Conv3DBackpropInputOp<GPUDevice, T>); \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("input_sizes"), \
Conv3DBackpropInputOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<GPUDevice, T>); \
Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("filter_sizes"), \
Conv3DBackpropFilterOp<GPUDevice, T>);
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("filter_sizes"), \
Conv3DBackpropFilterOp<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -22,7 +22,7 @@ REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double,
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Asinh", functor::asinh, float, double);
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER2(UnaryOp, GPU, "Asinh", functor::asinh, float, double);

View File

@ -22,7 +22,7 @@ REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double,
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Atanh", functor::atanh, float, double);
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);

View File

@ -231,7 +231,8 @@ static void CopyOutputBackpropRegion(const DepthwiseArgs& args,
}
// Pad to vector-register width (if needed).
for (int64 d = 0; d < pad_size; ++d) {
buffer[buf_base + vectorized_size + scalar_size + d] = static_cast<T>(0);
buffer[buf_base + vectorized_size + scalar_size + d] =
static_cast<T>(0);
}
}
}
@ -510,7 +511,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
#if GOOGLE_CUDA
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;
@ -885,7 +887,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
#if GOOGLE_CUDA
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;

View File

@ -158,7 +158,8 @@ struct DepthwiseFilterPadOp {
}
// Pad the remainder of output to vector-register boundary.
for (int64 j = 0; j < pad_size; ++j) {
padded_filter[output_base + vectorized_size + scalar_size + j] = static_cast<T>(0);
padded_filter[output_base + vectorized_size + scalar_size + j] =
static_cast<T>(0);
}
}
}

View File

@ -73,18 +73,22 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
std::move(other_arguments),
&captured_func));
*output = new Dataset(input, std::move(captured_func), cycle_length,
block_length, output_types_, output_shapes_);
*output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
block_length, output_types_, output_shapes_);
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
Dataset(const DatasetBase* input,
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: input_(input),
: GraphDatasetBase(ctx),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
@ -110,13 +114,47 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "InterleaveDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
std::vector<NodeBuilder::NodeOut> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
b->BuildAttrValue(func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{{0, input_node}, {2, cycle_length_node}, {3, block_length_node}},
{{1, other_arguments}},
{{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
current_elements_(params.dataset->cycle_length_) {}
current_elements_(params.dataset->cycle_length_),
args_list_(params.dataset->cycle_length_) {}
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
block_index_ = 0;
@ -150,18 +188,19 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
// We have reached the end of the current element, so move
// on to the next element in the cycle.
current_elements_[cycle_index_].reset();
args_list_[cycle_index_].clear();
--num_open_;
AdvanceToNextInCycle();
} else if (!end_of_input_) {
// Get the next element from the input dataset, and create
// an iterator from it.
std::vector<Tensor> args;
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, &args, &end_of_input_));
TF_RETURN_IF_ERROR(input_impl_->GetNext(
ctx, &args_list_[cycle_index_], &end_of_input_));
if (!end_of_input_) {
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
ctx, args, cycle_index_, dataset()->captured_func_.get(),
prefix(), &current_elements_[cycle_index_]));
ctx, args_list_[cycle_index_], cycle_index_,
dataset()->captured_func_.get(), prefix(),
&current_elements_[cycle_index_]));
++num_open_;
}
} else {
@ -173,11 +212,100 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("cycle_index"), cycle_index_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("block_index"), block_index_));
if (end_of_input_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input"), ""));
}
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("num_open"), num_open_));
TF_RETURN_IF_ERROR(SaveCurrentElements(writer));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
int64 cycle_index;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("cycle_index"), &cycle_index));
cycle_index_ = size_t(cycle_index);
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("block_index"), &block_index_));
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
int64 num_open;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("num_open"), &num_open));
num_open_ = size_t(num_open);
TF_RETURN_IF_ERROR(RestoreCurrentElements(ctx, reader));
return Status::OK();
}
private:
Status SaveCurrentElements(IteratorStateWriter* writer)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveParent(writer, current_elements_[idx]));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("args_size[", idx, "]")),
args_list_[idx].size()));
for (int i = 0; i < args_list_[idx].size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
args_list_[idx][i]));
}
}
}
return Status::OK();
}
Status RestoreCurrentElements(OpKernelContext* ctx,
IteratorStateReader* reader)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
int64 args_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("args_size[", idx, "]")),
&args_size));
args_list_[idx].resize(args_size);
for (int i = 0; i < args_size; i++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
&iter_ctx, args_list_[idx], idx,
dataset()->captured_func_.get(), prefix(),
&current_elements_[idx]));
TF_RETURN_IF_ERROR(
RestoreParent(ctx, reader, current_elements_[idx]));
} else {
current_elements_[idx].reset();
}
}
return Status::OK();
}
mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> current_elements_
GUARDED_BY(mu_);
std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
size_t cycle_index_ GUARDED_BY(mu_) = 0;
int64 block_index_ GUARDED_BY(mu_) = 0;
bool end_of_input_ GUARDED_BY(mu_) = false;
@ -185,6 +313,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
const NameAttrList func_;
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
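
The iterator above keeps up to cycle_length input elements open and pulls block_length items from each before advancing around the cycle. Below is a minimal standalone sketch of that round-robin behaviour over plain vectors; the function Interleave and its bookkeeping are illustrative only and are not the TensorFlow implementation.

#include <cstddef>
#include <iostream>
#include <vector>

// Interleaves `inputs` round-robin: keep up to `cycle_length` inputs open and
// take up to `block_length` values from the active one before advancing.
std::vector<int> Interleave(const std::vector<std::vector<int>>& inputs,
                            std::size_t cycle_length,
                            std::size_t block_length) {
  std::vector<int> out;
  std::vector<std::size_t> pos(inputs.size(), 0);  // read cursor per input
  std::vector<std::size_t> cycle;                  // indices of open inputs
  std::size_t next_input = 0;
  while (cycle.size() < cycle_length && next_input < inputs.size()) {
    cycle.push_back(next_input++);
  }
  std::size_t cycle_index = 0;
  while (!cycle.empty()) {
    const std::size_t in = cycle[cycle_index];
    std::size_t taken = 0;
    while (taken < block_length && pos[in] < inputs[in].size()) {
      out.push_back(inputs[in][pos[in]++]);
      ++taken;
    }
    if (pos[in] == inputs[in].size()) {
      // This element is exhausted: refill the slot or shrink the cycle.
      if (next_input < inputs.size()) {
        cycle[cycle_index] = next_input++;
      } else {
        cycle.erase(cycle.begin() + cycle_index);
        if (cycle.empty()) break;
        cycle_index %= cycle.size();
        continue;
      }
    }
    cycle_index = (cycle_index + 1) % cycle.size();
  }
  return out;
}

int main() {
  // With cycle_length = 2 and block_length = 1 this prints: 1 4 2 5 3 6 7 8 9
  for (int v : Interleave({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, 2, 1)) {
    std::cout << v << " ";
  }
  std::cout << "\n";
}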

View File

@ -258,7 +258,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureOutputAllocated(batch_result, result->return_values);
const size_t num_components = result->return_values.size();
for (size_t i = 0; i < num_components; ++i) {
Tensor tensor = result->return_values[i];
const Tensor& tensor = result->return_values[i];
Tensor* batch = &(batch_result->output)[i];
if (tensor.NumElements() !=
(batch->NumElements() / batch->dim_size(0))) {
@ -271,6 +271,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
", [batch]: ", batch_shape.DebugString()));
break;
}
// TODO(mrry): Add a version of DoParallelConcat that allows
// us to move `tensor` where possible, to speed up string
// tensor batching.
Status copy_status = ::tensorflow::functor::DoParallelConcat(
*dataset()->device_, tensor, offset, batch);
if (!copy_status.ok()) {
@ -279,6 +282,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
}
}
// NOTE(mrry): We clear the return values here to release any
// memory associated with them and to parallelize the destruction
// of the tensors (which can be surprisingly expensive for
// map functions with large numbers of return values).
result->return_values.clear();
batch_result->counter->DecrementCount();
});
}
@ -297,7 +305,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
for (size_t i = 0; i < dataset()->batch_size_; ++i) {
size_t index = ComputeInvocationIndex(batch_index, i);
InvocationResult* result = &invocation_results_[index];
*result = InvocationResult();
// Reset the state of `result`.
// NOTE(mrry): `result->return_values` were cleared when the previous
// invocation completed.
result->status = Status::OK();
}
// Start individual invocations.
for (size_t i = 0; i < dataset()->batch_size_; ++i) {
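
DoParallelConcat above writes each mapped element into the offset-th slice of a preallocated batch tensor, and clearing result->return_values afterwards releases the source tensors early, as the NOTE explains. A minimal standalone sketch of that slice copy for a row-major buffer follows; CopyIntoBatchSlice and its parameters are illustrative names, not the TensorFlow functor.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Copies `element` into the `offset`-th slice (along dimension 0) of a
// row-major batch buffer that holds batch_size * slice_size values.
void CopyIntoBatchSlice(const std::vector<float>& element, std::size_t offset,
                        std::size_t slice_size, std::vector<float>* batch) {
  assert(element.size() == slice_size);
  assert((offset + 1) * slice_size <= batch->size());
  std::copy(element.begin(), element.end(),
            batch->begin() + offset * slice_size);
}

int main() {
  std::vector<float> batch(3 * 4, 0.0f);  // batch_size = 3, slice_size = 4
  CopyIntoBatchSlice({1.0f, 2.0f, 3.0f, 4.0f}, 1, 4, &batch);  // fills row 1
}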

View File

@ -359,7 +359,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}
void Compute(OpKernelContext* context) override {
@ -888,7 +889,8 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}
void Compute(OpKernelContext* context) override {
@ -1052,7 +1054,8 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
"Pooling is not yet supported on the batch dimension."));
use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}
void Compute(OpKernelContext* context) override {
@ -1137,7 +1140,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
}
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}
void Compute(OpKernelContext* context) override {
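
The change above wraps ReadBoolFromEnvVar in TF_CHECK_OK so that a malformed TF_ENABLE_MAXPOOL_NANPROP value is reported instead of being silently ignored. A standalone sketch of that idea with a hypothetical helper (ReadBoolEnvOrDefault is not the TensorFlow API):

#include <cstdlib>
#include <stdexcept>
#include <string>

// Reads a boolean from an environment variable; an unparsable value is
// surfaced as an error rather than silently falling back to the default.
bool ReadBoolEnvOrDefault(const char* name, bool default_value) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) return default_value;
  const std::string value(raw);
  if (value == "1" || value == "true" || value == "TRUE") return true;
  if (value == "0" || value == "false" || value == "FALSE") return false;
  throw std::runtime_error(std::string("Invalid boolean value for ") + name +
                           ": " + value);
}

int main() {
  // Defaults to false when the variable is unset; throws on e.g. "maybe".
  return ReadBoolEnvOrDefault("TF_ENABLE_MAXPOOL_NANPROP", false) ? 1 : 0;
}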

View File

@ -405,17 +405,17 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
if (propagate_nans) {
MaxPoolForwardNHWC<true>
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>
(output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
kThreadsPerBlock, 0, d.stream()>>>(
output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
} else {
MaxPoolForwardNHWC<false>
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>
(output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
kThreadsPerBlock, 0, d.stream()>>>(
output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
}
return d.ok();
}
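
The launch configuration above sizes the grid with the usual ceiling division so that every output element is covered by a thread. A standalone sketch of that arithmetic in plain C++ (no CUDA dependency; NumBlocks is an illustrative name):

#include <cassert>
#include <cstdint>

// Number of thread blocks needed to cover `work_items` with blocks of
// `threads_per_block` threads: ceil(work_items / threads_per_block).
std::int64_t NumBlocks(std::int64_t work_items,
                       std::int64_t threads_per_block) {
  return (work_items + threads_per_block - 1) / threads_per_block;
}

int main() {
  assert(NumBlocks(1024, 256) == 4);  // exact multiple
  assert(NumBlocks(1025, 256) == 5);  // one extra, partially filled block
  assert(NumBlocks(1, 256) == 1);
}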

View File

@ -101,8 +101,8 @@ class MklToTfOp : public OpKernel {
// Allocate output tensor.
TensorShape output_shape = input_shape.GetTfShape();
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(input_number,
output_shape, &output_tensor));
OP_REQUIRES_OK(context, context->allocate_output(
input_number, output_shape, &output_tensor));
CHECK_NOTNULL(output_tensor);
// Do we need to reorder Mkl layout into TensorFlow layout?
@ -116,13 +116,13 @@ class MklToTfOp : public OpKernel {
// If not, just forward input tensor to output tensor.
CHECK(output_tensor->CopyFrom(input_tensor, output_shape));
}
} catch (mkldnn::error &e) {
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + std::string(e.message) +
", in file " + std::string(__FILE__) + ":" +
std::to_string(__LINE__);
OP_REQUIRES_OK(context,
errors::Aborted("Operation received an exception:", error_msg));
", message: " + std::string(e.message) + ", in file " +
std::string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
}
}
#else
@ -160,8 +160,8 @@ class MklToTfOp : public OpKernel {
// Allocate output tensor.
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(input_number,
output_shape, &output_tensor));
OP_REQUIRES_OK(context, context->allocate_output(input_number, output_shape,
&output_tensor));
dnnLayout_t output_layout =
static_cast<dnnLayout_t>(input_shape.GetTfLayout());

View File

@ -98,6 +98,19 @@ gtl::InlinedVector<T, 8> ComputeStride(const TensorShape& shape) {
return strides;
}
// Helper to compute 'strides' given an Eigen TensorDimensions
template <typename T, typename EigenDimensions>
gtl::InlinedVector<T, 8> ComputeEigenStrides(const EigenDimensions& shape) {
const int ndims = shape.rank();
gtl::InlinedVector<T, 8> strides(ndims);
T stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[i] = stride;
stride *= static_cast<T>(shape[i]);
}
return strides;
}
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
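
ComputeEigenStrides above builds row-major strides by walking the dimensions from innermost to outermost. The same arithmetic over a plain shape vector, as a standalone sketch (ComputeStrides here is illustrative, not the TensorFlow helper):

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides: the innermost dimension has stride 1, and each outer
// dimension's stride is the product of all inner dimension sizes.
std::vector<std::int64_t> ComputeStrides(
    const std::vector<std::int64_t>& shape) {
  std::vector<std::int64_t> strides(shape.size());
  std::int64_t stride = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= shape[i];
  }
  return strides;
}

int main() {
  // A {2, 3, 4} tensor has strides {12, 4, 1}: element (i, j, k) lives at
  // offset 12 * i + 4 * j + k.
  assert(ComputeStrides({2, 3, 4}) == (std::vector<std::int64_t>{12, 4, 1}));
}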

View File

@ -181,16 +181,18 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padding_values.push_back(tensor::DeepCopy(padding_value_t));
}
*output = new Dataset(batch_size, std::move(padded_shapes),
*output = new Dataset(ctx, batch_size, std::move(padded_shapes),
std::move(padding_values), input);
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
Dataset(int64 batch_size, std::vector<PartialTensorShape> padded_shapes,
Dataset(OpKernelContext* ctx, int64 batch_size,
std::vector<PartialTensorShape> padded_shapes,
std::vector<Tensor> padding_values, const DatasetBase* input)
: batch_size_(batch_size),
: GraphDatasetBase(ctx),
batch_size_(batch_size),
padded_shapes_(std::move(padded_shapes)),
padding_values_(std::move(padding_values)),
input_(input) {
@ -232,6 +234,47 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
")::Dataset");
}
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
std::vector<NodeBuilder::NodeOut> padded_shapes;
padded_shapes.reserve(padded_shapes_.size());
for (int i = 0; i < padded_shapes_.size(); i++) {
Node* node;
Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
for (int j = 0; j < padded_shapes_[i].dims(); j++) {
t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
}
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
padded_shapes.emplace_back(node);
}
std::vector<NodeBuilder::NodeOut> padding_values;
padding_values.reserve(padding_values_.size());
for (const Tensor& t : padding_values_) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
padding_values.emplace_back(node);
}
AttrValue output_types;
b->BuildAttrValue(output_dtypes(), &output_types);
AttrValue N;
b->BuildAttrValue<int64>(padded_shapes_.size(), &N);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {{0, input_graph_node}, {1, batch_size}},
{{2, padded_shapes}, {3, padding_values}},
{{"Toutput_types", output_types}, {"N", N}}, output));
return Status::OK();
}
private:
// Copies element into the index^th slice of parent (in the 0th dimension).
//
@ -248,17 +291,25 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
// Each row of `batch_elements` is a tuple of tensors from the
// input iterator.
std::vector<std::vector<Tensor>> batch_elements;
batch_elements.reserve(dataset()->batch_size_);
{
mutex_lock l(mu_);
*end_of_sequence = false;
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
++i) {
std::vector<Tensor> batch_element_tuple;
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
end_of_sequence));
if (!*end_of_sequence) {
batch_elements.push_back(std::move(batch_element_tuple));
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
} else {
*end_of_sequence = false;
batch_elements.reserve(dataset()->batch_size_);
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
++i) {
std::vector<Tensor> batch_element_tuple;
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
end_of_sequence));
if (!*end_of_sequence) {
batch_elements.push_back(std::move(batch_element_tuple));
}
}
if (*end_of_sequence) {
input_impl_.reset();
}
}
}
@ -347,6 +398,28 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (input_impl_)
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
else
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("exhausted"), ""));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (reader->Contains(full_name("exhausted"))) {
input_impl_.reset();
} else {
input_impl_ = dataset()->input_->MakeIterator(prefix());
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
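
PaddedBatchDataset pads every element in a batch out to a common shape before concatenating them. A minimal standalone sketch for 1-D elements of varying length; PaddedBatch is an illustrative name, and the real op additionally handles arbitrary ranks, partial padded shapes, and per-component padding values.

#include <algorithm>
#include <cstddef>
#include <vector>

// Pads each element with `pad_value` to the length of the longest element and
// returns the resulting dense batch.
std::vector<std::vector<int>> PaddedBatch(
    const std::vector<std::vector<int>>& elements, int pad_value) {
  std::size_t max_len = 0;
  for (const auto& e : elements) max_len = std::max(max_len, e.size());
  std::vector<std::vector<int>> batch;
  batch.reserve(elements.size());
  for (const auto& e : elements) {
    std::vector<int> padded(e);
    padded.resize(max_len, pad_value);
    batch.push_back(std::move(padded));
  }
  return batch;
}

int main() {
  // {{1}, {2, 3}, {4, 5, 6}} with pad_value 0 becomes
  // {{1, 0, 0}, {2, 3, 0}, {4, 5, 6}}.
  auto batch = PaddedBatch({{1}, {2, 3}, {4, 5, 6}}, 0);
  return batch.size() == 3 ? 0 : 1;
}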

View File

@ -352,13 +352,15 @@ class DeserializeSparseOp : public OpKernel {
i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
"] is: ", expanded_tensor_shape.dims() - 1));
for (int j = 1; j < shape.dims(); ++j) {
OP_REQUIRES(
context, shape.dim_size(j) == expanded_tensor_shape.dim_size(j),
errors::InvalidArgument(
"Inconsistent shape across SparseTensors: dimension ", j - 1,
" prior to SparseTensor[", i, "] was: ", shape.dim_size(j),
" but rank of SparseTensor[", i,
"] is: ", expanded_tensor_shape.dim_size(j)));
// NOTE(mrry): For compatibility with the implementations of
// DeserializeManySparse, and many ops that generate
// SparseTensors to batch that do not have a fixed
// dense_shape (e.g. `tf.parse_single_example()`), we
// compute the maximum in each dimension to find the
// smallest dense_shape that bounds all of the input
// SparseTensors.
shape.set_dim(j, std::max(shape.dim_size(j),
expanded_tensor_shape.dim_size(j)));
}
}
}
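
The NOTE above explains that the op now takes the element-wise maximum across the input dense_shapes, rather than requiring them to match exactly, to find the smallest dense_shape that bounds every input. A standalone sketch of that reduction (BoundingDenseShape is an illustrative name; equal rank is still assumed, as in the surrounding check):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Smallest dense shape that bounds every input shape: the element-wise maximum.
std::vector<std::int64_t> BoundingDenseShape(
    const std::vector<std::vector<std::int64_t>>& shapes) {
  assert(!shapes.empty());
  std::vector<std::int64_t> bound = shapes.front();
  for (const auto& shape : shapes) {
    assert(shape.size() == bound.size());  // ranks must still agree
    for (std::size_t j = 0; j < shape.size(); ++j) {
      bound[j] = std::max(bound[j], shape[j]);
    }
  }
  return bound;
}

int main() {
  // {3, 7} and {5, 2} are bounded by {5, 7}.
  assert(BoundingDenseShape({{3, 7}, {5, 2}}) ==
         (std::vector<std::int64_t>{5, 7}));
}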

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/cloud/file_block_cache.h"
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
@ -696,6 +697,18 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);
if (out->size() < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
if (stat_cache_->Lookup(filename, &stat)) {
if (offset + out->size() < stat.length) {
return errors::Internal(strings::Printf(
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));
}
}
}
return Status::OK();
}
@ -816,7 +829,8 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
return errors::Internal("'stat' cannot be nullptr.");
}
if (object.empty()) {
return errors::InvalidArgument("'object' must be a non-empty string.");
return errors::InvalidArgument(strings::Printf(
"'object' must be a non-empty string. (File: %s)", fname.c_str()));
}
StatCache::ComputeFunc compute_func =
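
The check above treats a short read as an internal error when the stat cache says more bytes should have been available at that offset; a short read is only valid at end of file. A standalone sketch of the same consistency test (ShortReadIsConsistent and its parameters are hypothetical names, not the GcsFileSystem API):

#include <cstddef>
#include <cstdint>

// When fewer than `block_size` bytes come back, the read is consistent only if
// `offset + bytes_read` reaches the known end of the file.
bool ShortReadIsConsistent(std::uint64_t offset, std::size_t bytes_read,
                           std::size_t block_size, std::uint64_t file_length) {
  if (bytes_read >= block_size) return true;  // a full block is always fine
  return offset + bytes_read >= file_length;  // short read must hit EOF
}

int main() {
  // A 10-byte result at offset 0 of a 100-byte file with 64 KiB blocks means
  // the read was interrupted, so the function reports the inconsistency.
  return ShortReadIsConsistent(0, 10, 64 * 1024, 100) ? 1 : 0;
}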

View File

@ -131,8 +131,8 @@ error::Code ErrnoToCode(int err_number) {
case ENETUNREACH: // Network unreachable
case ENOLCK: // No locks available
case ENOLINK: // Link has been severed
#if !(defined(__APPLE__) || defined(__FreeBSD__) || defined(_WIN32) \
|| defined(__HAIKU__))
#if !(defined(__APPLE__) || defined(__FreeBSD__) || defined(_WIN32) || \
defined(__HAIKU__))
case ENONET: // Machine is not on the network
#endif
code = error::UNAVAILABLE;

View File

@ -37,8 +37,8 @@ limitations under the License.
#ifdef TF_USE_SNAPPY
#include "snappy.h"
#endif
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) \
|| defined(__HAIKU__)
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) || \
defined(__HAIKU__)
#include <thread>
#endif
@ -62,8 +62,8 @@ int NumSchedulableCPUs() {
}
perror("sched_getaffinity");
#endif
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) \
|| defined(__HAIKU__)
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) || \
defined(__HAIKU__)
unsigned int count = std::thread::hardware_concurrency();
if (count > 0) return static_cast<int>(count);
#endif
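
The branch above falls back to std::thread::hardware_concurrency() on platforms without sched_getaffinity. A standalone sketch of that fallback; the default of 1 used when the value is unknown is an illustrative assumption, not what the surrounding function ultimately returns.

#include <thread>

// Best-effort CPU count: hardware_concurrency() may return 0 when the value
// cannot be determined, in which case fall back to a conservative default.
int NumSchedulableCpusFallback() {
  const unsigned int count = std::thread::hardware_concurrency();
  return count > 0 ? static_cast<int>(count) : 1;
}

int main() { return NumSchedulableCpusFallback() > 0 ? 0 : 1; }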

Some files were not shown because too many files have changed in this diff.