Merge pull request #14814 from yifeif/branch_176709725

Branch 176709725
Gunhan Gulsoy authored on 2017-11-22 20:35:17 -08:00; committed by GitHub.
commit ab0fcaceda
142 changed files with 3041 additions and 1538 deletions


@@ -905,6 +905,28 @@ def set_trisycl_include_dir(environ_cp):
   write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
                               trisycl_include_dir)

+
+def set_trisycl_include_dir(environ_cp):
+  """Set TRISYCL_INCLUDE_DIR."""
+  ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
+                             'include directory. (Use --config=sycl_trisycl '
+                             'when building with Bazel) '
+                             '[Default is %s]: ') % (
+                                 _DEFAULT_TRISYCL_INCLUDE_DIR)
+  while True:
+    trisycl_include_dir = get_from_env_or_user_or_default(
+        environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
+        _DEFAULT_TRISYCL_INCLUDE_DIR)
+    if os.path.exists(trisycl_include_dir):
+      break
+
+    print('Invalid triSYCL include directory, %s cannot be found' %
+          (trisycl_include_dir))
+
+  # Set TRISYCL_INCLUDE_DIR
+  environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
+  write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
+
+
 def set_mpi_home(environ_cp):
   """Set MPI_HOME."""
   default_mpi_home = which('mpirun') or which('mpiexec') or ''


@@ -189,7 +189,7 @@ def tf_library(name, graph, config,
            " --cpp_class=" + cpp_class +
            " --target_triple=" + target_llvm_triple() +
            " --out_session_module=$(@D)/" + session_module_pb +
-           flags),
+           " " + flags),
      tools=[tfcompile_tool],
      visibility=visibility,
      testonly=testonly,


@@ -76,7 +76,8 @@ class FusedBatchNormTest(XLATestCase):
       # To avoid constant folding
       t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
       scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
-      offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
+      offset = array_ops.placeholder(
+          np.float32, shape=scale_shape, name="offset")
       epsilon = 0.001
       y_ref, mean_ref, var_ref = self._reference_training(
           x_val, scale_val, offset_val, epsilon, data_format)
@@ -112,7 +113,8 @@ class FusedBatchNormTest(XLATestCase):
       # To avoid constant folding
       t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
       scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
-      offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
+      offset = array_ops.placeholder(
+          np.float32, shape=scale_shape, name="offset")
       epsilon = 0.001
       y, mean, var = nn.fused_batch_norm(
           t_val,


@@ -67,6 +67,15 @@ class Client {
     std::vector<GlobalData*> arguments;
     ExecutionOptions execution_options;
     ExecutionProfile* execution_profile;
+
+    ComputationInstance(const Computation& computation,
+                        std::vector<GlobalData*> arguments,
+                        ExecutionOptions execution_options,
+                        ExecutionProfile* execution_profile)
+        : computation(computation),
+          arguments(std::move(arguments)),
+          execution_options(execution_options),
+          execution_profile(execution_profile) {}
   };

   // Executes a list ComputationInstances and returns global data produced from
@@ -133,7 +142,7 @@ class Client {
   // Returns a vector of global data handles that point to the tuple elements.
   StatusOr<std::vector<std::unique_ptr<GlobalData>>> DeconstructTuple(
-      const GlobalData& computation);
+      const GlobalData& data);

   // Retrieves the statistics of the given computation.
   StatusOr<ComputationStats> GetComputationStats(
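
Note on the client.h change: the new ComputationInstance constructor mainly lets call sites build the list handed to the batched-execution entry point in one expression instead of four separate member assignments. A minimal usage sketch follows; the helper name RunTwice, the parameter names, and the assumption that the entry point is Client::ExecuteParallel are illustrative, not part of the diff.

    // Hypothetical usage sketch (not part of the commit).
    #include "tensorflow/compiler/xla/client/client.h"

    xla::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>> RunTwice(
        xla::Client* client, const xla::Computation& computation,
        const std::vector<xla::GlobalData*>& args,
        const xla::ExecutionOptions& options) {
      std::vector<xla::Client::ComputationInstance> instances;
      // One expression per instance, thanks to the new constructor.
      instances.emplace_back(computation, args, options,
                             /*execution_profile=*/nullptr);
      instances.emplace_back(computation, args, options,
                             /*execution_profile=*/nullptr);
      return client->ExecuteParallel(instances);  // assumed batched entry point
    }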


@@ -85,9 +85,9 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
                                                 HloOpcode opcode) {
     HloComputation::Builder b("scalar_computation");
     auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
-        0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs"));
+        0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
     auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
-        1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs"));
+        1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
     auto scalar_op = b.AddInstruction(
         HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
                                      opcode, scalar_lhs, scalar_rhs));
@@ -152,22 +152,30 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
   // Expand batch norm training into smaller HLO ops.
   HloInstruction* operand = batch_norm->mutable_operand(0);
   const Shape operand_shape = operand->shape();
+  PrimitiveType ptype = operand_shape.element_type();
   int64 feature_index = batch_norm->feature_index();
   const int64 feature_count = operand_shape.dimensions(feature_index);
   const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
-  auto elements_per_feature =
-      computation_->AddInstruction(HloInstruction::CreateConstant(
-          Literal::CreateR0<float>(size_in_elements / feature_count)));
+  auto elements_per_feature_literal =
+      Literal::CreateR0<float>(size_in_elements / feature_count);
+  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
+                      elements_per_feature_literal->Convert(ptype));
+  auto elements_per_feature = computation_->AddInstruction(
+      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

   HloInstruction* scale = batch_norm->mutable_operand(1);
   HloInstruction* offset = batch_norm->mutable_operand(2);
   const Shape feature_shape = scale->shape();

+  auto zero_literal = Literal::CreateR0(0.0f);
+  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
   auto zero = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+      HloInstruction::CreateConstant(std::move(zero_literal)));

+  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
   auto epsilon = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+      HloInstruction::CreateConstant(std::move(epsilon_literal)));

   std::vector<int64> dimensions_without_feature;
@@ -184,7 +192,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
       HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

   HloComputation* add_reduce_computation =
-      GetScalarBinaryComputation(F32, HloOpcode::kAdd);
+      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

   // X^2.
   auto operand_squared =
@@ -243,8 +251,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
       computation_->AddInstruction(HloInstruction::CreateBinary(
           operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

+  auto neg_half_literal = Literal::CreateR0(-0.5f);
+  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
   auto neg_half = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
+      HloInstruction::CreateConstant(std::move(neg_half_literal)));

   // 1 / Sqrt[Var[X] + epsilon].
   auto rsqrt_var_add_epsilon =
@@ -286,6 +296,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
   HloInstruction* operand = batch_norm->mutable_operand(0);
   const Shape operand_shape = operand->shape();
   int64 feature_index = batch_norm->feature_index();
+  PrimitiveType ptype = operand_shape.element_type();

   HloInstruction* scale = batch_norm->mutable_operand(1);
   HloInstruction* offset = batch_norm->mutable_operand(2);
@@ -293,8 +304,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
   HloInstruction* var = batch_norm->mutable_operand(4);
   const Shape feature_shape = scale->shape();

+  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
   auto epsilon = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+      HloInstruction::CreateConstant(std::move(epsilon_literal)));

   std::vector<int64> dimensions_without_feature;
@@ -321,8 +334,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
       computation_->AddInstruction(HloInstruction::CreateBinary(
           operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

+  auto neg_half_literal = Literal::CreateR0(-0.5f);
+  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
   auto neg_half = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
+      HloInstruction::CreateConstant(std::move(neg_half_literal)));

   // 1 / Sqrt[Var[X] + epsilon].
   auto rsqrt_var_add_epsilon =
@@ -373,6 +388,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
   HloInstruction* activation = batch_norm->mutable_operand(0);
   const Shape activation_shape = activation->shape();
+  PrimitiveType ptype = activation_shape.element_type();
   HloInstruction* scale = batch_norm->mutable_operand(1);
   const Shape feature_shape = scale->shape();
   HloInstruction* mean = batch_norm->mutable_operand(2);
@@ -383,18 +399,27 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
   const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
   const int64 feature_count = activation_shape.dimensions(feature_index);
-  auto elements_per_feature =
-      computation_->AddInstruction(HloInstruction::CreateConstant(
-          Literal::CreateR0<float>(size_in_elements / feature_count)));
+  auto elements_per_feature_literal =
+      Literal::CreateR0<float>(size_in_elements / feature_count);
+  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
+                      elements_per_feature_literal->Convert(ptype));
+  auto elements_per_feature = computation_->AddInstruction(
+      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

+  auto zero_literal = Literal::CreateR0(0.0f);
+  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
   auto zero = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+      HloInstruction::CreateConstant(std::move(zero_literal)));

+  auto neg_half_literal = Literal::CreateR0(-0.5f);
+  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
   auto neg_half = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
+      HloInstruction::CreateConstant(std::move(neg_half_literal)));

+  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
   auto epsilon = computation_->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+      HloInstruction::CreateConstant(std::move(epsilon_literal)));

   std::vector<int64> dimensions_without_feature;
@@ -442,7 +467,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
           grad_output, activation_minus_mean));

   HloComputation* add_reduce_computation =
-      GetScalarBinaryComputation(F32, HloOpcode::kAdd);
+      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

   // sum(Grad[Y] * (X - E[X])).
   auto sum_grad_output_times_activiation_minus_mean =
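
The batchnorm_rewriter hunks above all repeat one pattern: build an F32 scalar literal, convert it to the operand's element type, and wrap it in a constant instruction, so the rewritten scalars match the operand's type instead of being hard-coded to F32. A condensed sketch of just that pattern; AddScalarConstant is a hypothetical helper name, not a function in the commit.

    // Sketch only, under the same XLA helpers used in the hunks above.
    StatusOr<HloInstruction*> AddScalarConstant(HloComputation* computation,
                                                float value, PrimitiveType ptype) {
      auto literal = Literal::CreateR0(value);
      // Convert() yields StatusOr<std::unique_ptr<Literal>>; TF_ASSIGN_OR_RETURN
      // unwraps it or propagates the error to the caller.
      TF_ASSIGN_OR_RETURN(literal, literal->Convert(ptype));
      return computation->AddInstruction(
          HloInstruction::CreateConstant(std::move(literal)));
    }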


@@ -197,28 +197,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
 class CollectProfileCandidates : public DfsHloVisitorWithDefault {
  public:
   static StatusOr<std::unordered_map<const HloInstruction*, size_t>>
-  GetCandidatesForComputation(HloComputation* computation) {
+  GetCandidatesForComputation(
+      HloComputation* computation,
+      const std::unordered_map<const HloInstruction*, int64>&
+          assigned_indices) {
     std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
     CollectProfileCandidates profile_candidates_for_computation(
-        &hlo_to_profile_idx);
+        &hlo_to_profile_idx, assigned_indices);
     TF_RETURN_IF_ERROR(
         computation->Accept(&profile_candidates_for_computation));
     return hlo_to_profile_idx;
   }

  private:
-  explicit CollectProfileCandidates(
-      std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
-      : hlo_to_profile_idx_(hlo_to_profile_idx) {}
+  CollectProfileCandidates(
+      std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx,
+      const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
+      : hlo_to_profile_idx_(hlo_to_profile_idx),
+        assigned_indices_(assigned_indices) {}

   Status DefaultAction(HloInstruction* hlo_instruction) override {
-    hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()});
+    hlo_to_profile_idx_->insert(
+        {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
     return Status::OK();
   }

   Status HandleCall(HloInstruction* call) override {
     TF_RETURN_IF_ERROR(DefaultAction(call));
-    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
+    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
+                                                 assigned_indices_);
     TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
     return Status::OK();
   }
@@ -232,17 +239,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
   Status HandleWhile(HloInstruction* xla_while) override {
     TF_RETURN_IF_ERROR(DefaultAction(xla_while));

-    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_);
+    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
+                                                      assigned_indices_);
     TF_RETURN_IF_ERROR(
         xla_while->while_condition()->Accept(&candidates_for_condition));

-    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_);
+    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
+                                                 assigned_indices_);
     TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));
     return Status::OK();
   }

   std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
+  const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
 };

 }  // namespace
@@ -475,10 +485,27 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   HloComputation* computation = module->entry_computation();
   std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
+  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
+  std::unique_ptr<HloProfilePrinter> hlo_profile_printer;
+
   if (module->config().hlo_profiling_enabled()) {
+    hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
     TF_ASSIGN_OR_RETURN(
         hlo_to_profile_idx,
-        CollectProfileCandidates::GetCandidatesForComputation(computation));
+        CollectProfileCandidates::GetCandidatesForComputation(
+            computation, hlo_profile_index_map->instruction_to_profile_idx()));
+
+    auto shape_size_bytes = [](const Shape& shape) {
+      // On the cpu, opaques are pointers.
+      if (ShapeUtil::IsOpaque(shape)) {
+        return static_cast<int64>(sizeof(void*));
+      }
+      return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+    };
+
+    HloCostAnalysis cost_analysis(shape_size_bytes);
+    hlo_profile_printer =
+        CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis);
   }

   std::unique_ptr<Executable> cpu_executable;
@@ -544,8 +571,16 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
         parallel_computations.emplace(to_apply, instruction);
       }

+    // We always profile the entire computation as a whole, even if hlo
+    // profiling is disabled. When hlo profiling is diabled, we pass in a
+    // profile counter array of just one element, which corresponds to the whole
+    // computation.
+    size_t entry_computation_profile_idx =
+        hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
+                                    *module->entry_computation())
+                              : 0;
+
     IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
-                         hlo_to_profile_idx, hlo_to_profile_idx.size(),
+                         hlo_to_profile_idx, entry_computation_profile_idx,
                          jit->target_machine(), jit->external_constant_pool());

     std::unique_ptr<HloInstructionMap<string>> function_names(
@@ -586,8 +621,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     jit->AddModule(std::move(llvm_module));
     cpu_executable.reset(new ParallelCpuExecutable(
         std::move(jit), std::move(assignment), std::move(module),
-        std::move(function_names), std::move(hlo_to_profile_idx),
-        std::move(aligned_constants)));
+        std::move(function_names), std::move(aligned_constants),
+        std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));

     if (embed_ir_in_executable) {
       static_cast<CpuExecutable&>(*cpu_executable)
@@ -620,12 +655,22 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
       TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
           proto, xla_dump_hlo_proto_to, module->name()));
     }

+    // We always profile the entire computation as a whole, even if hlo
+    // profiling is disabled. When hlo profiling is diabled, we pass in a
+    // profile counter array of just one element, which corresponds to the whole
+    // computation.
+    size_t entry_computation_profile_idx =
+        hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
+                                    *module->entry_computation())
+                              : 0;
+
     // Each computation is a single function. Emit all embedded computations
     // before the entry computation. The order of computations returned from
     // GetEmbeddedComputations guarantees that a called computation occurs
     // before a caller computation.

     IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
-                         hlo_to_profile_idx, hlo_to_profile_idx.size(),
+                         hlo_to_profile_idx, entry_computation_profile_idx,
                          jit->target_machine(), jit->external_constant_pool());

     for (auto embedded_computation :
@@ -659,7 +704,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     jit->AddModule(std::move(llvm_module));
     cpu_executable.reset(new CpuExecutable(
         std::move(jit), std::move(assignment), std::move(module), function_name,
-        std::move(hlo_to_profile_idx)));
+        std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));

     if (embed_ir_in_executable) {
       static_cast<CpuExecutable&>(*cpu_executable)
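
Both CpuCompiler code paths above (sequential and parallel) now follow the same recipe when HLO profiling is enabled: build an HloProfileIndexMap from the module, derive an HloProfilePrinter from it via a cost analysis, and hand both to the executable; with profiling off, both stay null and only a single whole-computation counter (index 0) is used. A condensed restatement of that recipe, assuming module and the shape-size lambda are in scope as in the hunk:

    // Condensed restatement of the setup above, not additional behavior.
    std::unique_ptr<HloProfileIndexMap> index_map;
    std::unique_ptr<HloProfilePrinter> printer;
    if (module->config().hlo_profiling_enabled()) {
      index_map = MakeUnique<HloProfileIndexMap>(*module);
      HloCostAnalysis cost_analysis(shape_size_bytes);  // same lambda as above
      printer = CreateHloProfilePrinter(*index_map, cost_analysis);
    }
    // Either both are set or both are null; Executable's constructor
    // CHECK-enforces that invariant (see the executable.h hunk below).
    size_t entry_idx =
        index_map ? index_map->GetProfileIndexFor(*module->entry_computation())
                  : 0;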


@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/host/host_stream.h"

 namespace se = ::perftools::gputools;
@@ -54,11 +55,12 @@ CpuExecutable::CpuExecutable(
     std::unique_ptr<const BufferAssignment> assignment,
     std::unique_ptr<const HloModule> hlo_module,
     const string& entry_function_name,
-    std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx)
-    : Executable(std::move(hlo_module)),
+    std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
+    : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+                 std::move(hlo_profile_index_map)),
       jit_(std::move(jit)),
-      assignment_(std::move(assignment)),
-      hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) {
+      assignment_(std::move(assignment)) {
   // Resolve symbols in the constructor rather than at execution time to avoid
   // races because FindSymbol is not thread safe.
   llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name);
@@ -182,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction(
   uint64 start_micros = tensorflow::Env::Default()->NowMicros();

   // Allocate profiling counters for each hlo instruction that we would like to
-  // profile. Allocate an additional profile counter for the entire
-  // computation.
-  std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
+  // profile. Even when not Hlo profiling, we allocate a counter for the entire
+  // computation, which we use to update ExecutionProfile below.
+  std::vector<int64>* profile_counters = nullptr;
+  std::vector<int64> profile_counter_for_entry_computation;
+  if (hlo_execution_profile) {
+    profile_counters = hlo_execution_profile->mutable_profile_counters();
+  } else {
+    profile_counters = &profile_counter_for_entry_computation;
+    profile_counter_for_entry_computation.push_back(0);
+  }

   // Call the computation function following the calling convention.
   std::vector<void*> buffer_pointers;
@@ -199,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction(
     VLOG(3) << tensorflow::strings::Printf(
         "  func(void* result, void* params[%zu], void* temps[%zu], "
         "uint64 profile_counters[%zu])",
-        args_array.size(), buffer_pointers.size(), profile_counters.size());
+        args_array.size(), buffer_pointers.size(), profile_counters->size());
     VLOG(3) << tensorflow::strings::Printf("    result = %p", result_buffer);
     auto ptr_printer = [](string* out, const void* p) {
       tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
@@ -211,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction(
         "    temps = [%s]",
         tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
     VLOG(3) << tensorflow::strings::Printf("    profile_counters = %p",
-                                           profile_counters.data());
+                                           profile_counters->data());
   }

   compute_function_(result_buffer, run_options, args_array.data(),
-                    buffer_pointers.data(), profile_counters.data());
+                    buffer_pointers.data(), profile_counters->data());

   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -224,20 +233,46 @@ Status CpuExecutable::ExecuteComputeFunction(
     const double nanoseconds = (end_micros - start_micros) * 1000.0;
     execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));

-    // The last profile counter is used for the computation as a whole.
-    execution_profile_.set_compute_cycle_count(profile_counters.back());
+    if (hlo_execution_profile) {
+      execution_profile_.set_compute_cycle_count(
+          hlo_execution_profile->total_cycles_executed(
+              *module().entry_computation()));
+    } else {
+      execution_profile_.set_compute_cycle_count(profile_counters->back());
+    }
   }

-  if (hlo_execution_profile != nullptr) {
-    hlo_execution_profile->set_total_cycles_executed(
-        *module().entry_computation(), profile_counters.back());
-
-    for (auto hlo_prof_idx : hlo_to_profile_idx_) {
-      const HloInstruction* hlo = hlo_prof_idx.first;
-      uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
-      hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
+  return Status::OK();
+}
+
+static void LogLiveAddresses(
+    const std::unordered_set<const void*>& marked_addresses) {
+  VLOG(3) << "Live addresses in output marking found "
+          << marked_addresses.size() << " addresses:\n"
+          << tensorflow::str_util::Join(
+                 marked_addresses, ", ", [](string* out, const void* address) {
+                   tensorflow::strings::StrAppend(
+                       out, tensorflow::strings::Printf("%p", address));
+                 });
+}
+
+static Status DeallocateTempBuffers(
+    DeviceMemoryAllocator* allocator, se::Stream* stream,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+    const std::unordered_set<const void*>& marked_addresses) {
+  // Keep those marked live because they are referenced by the output of the
+  // computation and are needed by the service. They will be deallocated by the
+  // service.
+  for (size_t i = 0; i < buffers.size(); ++i) {
+    se::DeviceMemoryBase alloc = buffers[i];
+    if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) {
+      VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
+              << alloc.opaque() << "]";
+      TF_RETURN_IF_ERROR(
+          allocator->Deallocate(stream->parent()->device_ordinal(), &alloc));
     }
   }

   return Status::OK();
 }
@@ -263,26 +298,9 @@ StatusOr<perftools::gputools::DeviceMemoryBase> CpuExecutable::ExecuteOnStream(
   MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
                             &marked_addresses);

-  VLOG(3) << "Live addresses in output marking found "
-          << marked_addresses.size() << " addresses:\n"
-          << tensorflow::str_util::Join(
-                 marked_addresses, ", ", [](string* out, const void* address) {
-                   tensorflow::strings::StrAppend(
-                       out, tensorflow::strings::Printf("%p", address));
-                 });
-
-  // Computation is done - deallocate temp buffers. Keep those marked live
-  // because they are referenced by the output of the computation and are needed
-  // by the service. They will be deallocated by the service.
-  for (size_t i = 0; i < buffers.size(); ++i) {
-    se::DeviceMemoryBase alloc = buffers[i];
-    if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) {
-      VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
-              << alloc.opaque() << "]";
-      TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
-          stream->parent()->device_ordinal(), &alloc));
-    }
-  }
+  LogLiveAddresses(marked_addresses);
+
+  TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers,
+                                           marked_addresses));

   return top_level_output;
 }
@@ -360,9 +378,44 @@ StatusOr<perftools::gputools::DeviceMemoryBase>
 CpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
-  // TODO(b/30671675): Implement asynchronous execution mode.
-  return Unimplemented(
-      "Asynchronous execution on stream is not yet supported on CPU.");
+  if (hlo_profiling_enabled()) {
+    return Unimplemented(
+        "Asynchronous execution on stream with hlo profiling is not yet "
+        "supported on CPU.");
+  }
+
+  auto* host_stream = dynamic_cast<perftools::gputools::host::HostStream*>(
+      run_options->stream()->implementation());
+  se::Stream* stream = run_options->stream();
+  DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+  std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+
+  TF_RETURN_IF_ERROR(AllocateBuffers(
+      memory_allocator, stream->parent()->device_ordinal(), &buffers));
+
+  // Mark the buffers that are actually live (used in the output) when the
+  // computation finishes executing.
+  std::unordered_set<const void*> marked_addresses;
+  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
+                      assignment_->GetUniqueTopLevelOutputSlice());
+  se::DeviceMemoryBase top_level_output = buffers[result_slice.index()];
+  MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
+                            &marked_addresses);
+
+  LogLiveAddresses(marked_addresses);
+
+  host_stream->EnqueueTask([this, run_options, arguments, buffers,
+                            marked_addresses, memory_allocator, stream]() {
+    // Failing a CHECK here is not great, but I don't see an obvious way to
+    // return a failed Status asynchronously.
+    TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments,
+                                       buffers,
+                                       /*hlo_execution_profile=*/nullptr));
+    TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers,
+                                      marked_addresses));
+  });
+
+  return top_level_output;
 }
@@ -378,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const {
       module().entry_computation()->root_instruction());
 }

-std::unique_ptr<HloCostAnalysis> CpuExecutable::CreateCostAnalysis() const {
-  return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
-}
-
 }  // namespace cpu
 }  // namespace xla
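
The new ExecuteAsyncOnStream above works because the CPU backend's stream implementation is a plain ordered task queue: closures passed to EnqueueTask run sequentially on a thread owned by the stream, so the function can return the not-yet-populated output buffer immediately and callers synchronize on the stream later. A minimal illustration of that property follows; the Demo function is hypothetical and not part of the commit, and in XLA code the wait normally happens through the se::Stream wrapper rather than BlockUntilDone directly.

    #include "tensorflow/stream_executor/host/host_stream.h"

    namespace host = perftools::gputools::host;

    // Hypothetical demo: HostStream executes enqueued closures in order on
    // its own background thread.
    void Demo(host::HostStream* stream) {
      int value = 0;
      stream->EnqueueTask([&value]() { value = 42; });  // runs asynchronously
      stream->BlockUntilDone();  // synchronize before reading results
      // value == 42 here.
    }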


@@ -47,12 +47,12 @@ namespace cpu {
 // architecture, so JIT-ed code and host code share the same ABI.
 class CpuExecutable : public Executable {
  public:
-  CpuExecutable(
-      std::unique_ptr<SimpleOrcJIT> jit,
-      std::unique_ptr<const BufferAssignment> assignment,
-      std::unique_ptr<const HloModule> hlo_module,
-      const string& entry_function_name,
-      std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx);
+  CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
+                std::unique_ptr<const BufferAssignment> assignment,
+                std::unique_ptr<const HloModule> hlo_module,
+                const string& entry_function_name,
+                std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
   ~CpuExecutable() override {}

   StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
@@ -85,12 +85,10 @@ class CpuExecutable : public Executable {
   static int64 ShapeSizeBytes(const Shape& shape);

-  std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
-
   // Type of the computation function we expect in the JIT.
   using ComputeFunctionType = void (*)(
       void* /*result*/, const ExecutableRunOptions* /*run_options*/,
-      const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/);
+      const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/);

   const ComputeFunctionType& compute_function() const {
     return compute_function_;
@@ -145,9 +143,6 @@ class CpuExecutable : public Executable {
   // Entry function name for the computation.
   const string entry_function_name_;

-  // Maps HLOs to their index into the profile counter array.
-  const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
 };


@@ -59,19 +59,20 @@ ParallelCpuExecutable::ParallelCpuExecutable(
     std::unique_ptr<const BufferAssignment> assignment,
     std::unique_ptr<const HloModule> hlo_module,
     std::unique_ptr<const HloInstructionMap<string>> function_names,
-    std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
     std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
-        aligned_constants)
-    : Executable(std::move(hlo_module)),
+        aligned_constants,
+    std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
+    : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+                 std::move(hlo_profile_index_map)),
       jit_(std::move(jit)),
      assignment_(std::move(assignment)),
      function_names_(std::move(function_names)),
-      hlo_to_profile_idx_(std::move(hlo_to_profile_idx)),
      aligned_constants_(std::move(aligned_constants)) {}

 // Type of the computation function we expect in the JIT.
 using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
-                                     int64*, uint64*);
+                                     int64*, int64*);
@@ -106,7 +107,7 @@ class Executor {
            const ServiceExecutableRunOptions* run_options,
            std::list<HloInstruction*>* pending,
            HloInstructionMap<const void*>* results, void** temps_array,
-           uint64* profile_counters_array, const BufferAssignment* assignment)
+           int64* profile_counters_array, const BufferAssignment* assignment)
       : functions_(functions),
        run_options_(run_options),
        pending_(pending),
@@ -147,7 +148,7 @@ class Executor {
   std::list<HloInstruction*>* pending_;
   HloInstructionMap<const void*>* results_;
   void** temps_array_;
-  uint64* profile_counters_array_;
+  int64* profile_counters_array_;
   tensorflow::thread::ThreadPool* thread_pool_;
   const BufferAssignment* assignment_;
@@ -389,9 +390,11 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
     HloExecutionProfile* hlo_execution_profile) {
   // Allocate profiling counters for each hlo instruction that we would like to
-  // profile. Allocate an additional profile counter for the entire
-  // computation.
-  std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
+  // profile.
+  std::vector<int64>* profile_counters = nullptr;
+  if (hlo_execution_profile) {
+    profile_counters = hlo_execution_profile->mutable_profile_counters();
+  }

   std::vector<void*> buffer_pointers;
   buffer_pointers.reserve(buffers.size());
@@ -441,9 +444,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
   // For example, if we expect a library conv/matmul call to run at max
   // concurrency, we should not dispatch runnable instructions until the
   // library call is finished (to avoid expensive cache invalidation).
-  Executor executor(functions, run_options, &pending, &results,
-                    buffer_pointers.data(), profile_counters.data(),
-                    assignment_.get());
+  Executor executor(
+      functions, run_options, &pending, &results, buffer_pointers.data(),
+      profile_counters ? profile_counters->data() : nullptr, assignment_.get());

   TF_RETURN_IF_ERROR(executor.Run());
@@ -453,18 +456,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
     tensorflow::mutex_lock lock(mutex_);
     double nanoseconds = (end_micros - start_micros) * 1000.0;
     execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
-
-    // The last profile counter is used for the computation as a whole.
-    execution_profile_.set_compute_cycle_count(profile_counters.back());
-  }
-
-  if (hlo_execution_profile != nullptr) {
-    hlo_execution_profile->set_total_cycles_executed(entry_computation,
-                                                     profile_counters.back());
-    for (auto hlo_prof_idx : hlo_to_profile_idx_) {
-      const HloInstruction* hlo = hlo_prof_idx.first;
-      uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
-      hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
-    }
   }

   return Status::OK();
@@ -618,10 +609,5 @@ const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
       module().entry_computation()->root_instruction());
 }

-std::unique_ptr<HloCostAnalysis> ParallelCpuExecutable::CreateCostAnalysis()
-    const {
-  return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
-}
-
 }  // namespace cpu
 }  // namespace xla


@@ -52,10 +52,11 @@ class ParallelCpuExecutable : public Executable {
       std::unique_ptr<const BufferAssignment> assignment,
       std::unique_ptr<const HloModule> hlo_module,
       std::unique_ptr<const HloInstructionMap<string>> function_names,
-      std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
       std::unordered_map<const HloInstruction*,
                          std::unique_ptr<unsigned char[]>>
-          aligned_constants);
+          aligned_constants,
+      std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+      std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
   ~ParallelCpuExecutable() override {}

   StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
@@ -95,8 +96,6 @@ class ParallelCpuExecutable : public Executable {
         "Equality test on CPU parallel executable is not implemented.");
   }

-  std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
-
  private:
   // Allocate buffers required for execution and assign them to the elements of
   // "buffers". "buffers" should be sized to the number of buffers in buffer
@@ -143,9 +142,6 @@ class ParallelCpuExecutable : public Executable {
   // Map containing the JITted function names for each HLO instruction.
   const std::unique_ptr<const HloInstructionMap<string>> function_names_;

-  // Maps HLOs to their index into the profile counter array.
-  const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
-
   // Map from HLO Constant instructions to a pointer to their literal data.
   // The data stored in the protocol buffer might be insufficiently aligned,
   // we create a sufficiently aligned copy and store it in this map.


@@ -44,8 +44,15 @@ namespace xla {
 // interface that is used for launching compiled programs across platforms.
 class Executable {
  public:
-  explicit Executable(std::unique_ptr<const HloModule> hlo_module)
-      : hlo_module_(std::move(hlo_module)) {}
+  explicit Executable(std::unique_ptr<const HloModule> hlo_module,
+                      std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+                      std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
+      : hlo_module_(std::move(hlo_module)),
+        hlo_profile_printer_(std::move(hlo_profile_printer)),
+        hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
+    CHECK_EQ(hlo_profile_printer_.get() == nullptr,
+             hlo_profile_index_map_.get() == nullptr);
+  }
+
   virtual ~Executable() {}

   // Enqueues the compilation result on the provided stream, passing the given
@@ -123,12 +130,20 @@ class Executable {
         "Equality test on this executable is not implemented.");
   }

+  const HloProfilePrinter& hlo_profile_printer() const {
+    CHECK(hlo_profiling_enabled());
+    return *hlo_profile_printer_;
+  }
+
+  const HloProfileIndexMap& hlo_profile_index_map() const {
+    CHECK(hlo_profiling_enabled());
+    return *hlo_profile_index_map_;
+  }
+
   // Returns whether this executable was compiled with HLO profilings support
   // enabled. If not, the caller should not expect an hlo_execution_profile
   // passed to ExecuteOnStream above to be populated during execution.
-  bool hlo_profiling_enabled() const {
-    return hlo_module_->config().hlo_profiling_enabled();
-  }
+  bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }

   const HloModule& module() const { return *hlo_module_; }
@@ -160,10 +175,6 @@ class Executable {
   static Status DumpToDirectory(const string& directory_path, string filename,
                                 const SessionModule& session_module);

-  // Returns a cost analysis object appropriate for the platform on which this
-  // executable can run.
-  virtual std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const = 0;
-
  protected:
   mutable tensorflow::mutex mutex_;
@@ -181,6 +192,9 @@ class Executable {
   // Execution count, used to generate a unique filename for each dumped
   // execution.
   int64 execution_count_ = 0;
+
+  std::unique_ptr<HloProfilePrinter> hlo_profile_printer_;
+  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
 };

 template <typename ReturnT, typename ArgT>
@@ -200,7 +214,8 @@ StatusOr<ReturnT> Executable::ExecuteOnStreamWrapper(
   std::unique_ptr<HloExecutionProfile> profile_ptr =
       module_config().debug_options().xla_hlo_profile() &&
              hlo_profiling_enabled()
-          ? MakeUnique<HloExecutionProfile>(module(), *CreateCostAnalysis())
+          ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer(),
+                                            &hlo_profile_index_map())
          : nullptr;

  auto return_value =
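
With the executable.h change above, the base class owns the profiling metadata and the contract is: a backend passes both a printer and an index map (profiling compiled in) or neither (profiling off); the CHECK_EQ in the constructor and the new hlo_profiling_enabled() encode exactly that. A stub subclass makes the contract explicit; MyExecutable is illustrative only, not a backend in this commit.

    // Illustrative stub of the new base-class contract.
    class MyExecutable : public Executable {
     public:
      MyExecutable(std::unique_ptr<const HloModule> module,
                   std::unique_ptr<HloProfilePrinter> printer,
                   std::unique_ptr<HloProfileIndexMap> index_map)
          : Executable(std::move(module), std::move(printer),
                       std::move(index_map)) {}
      // Passing nullptr for both (as the interpreter backend does below) makes
      // hlo_profiling_enabled() return false; passing only one trips CHECK_EQ.
    };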


@@ -465,10 +465,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   VLOG(2) << "Printing the thunk schedule...";
   XLA_VLOG_LINES(2, thunk_schedule->ToString());

-  auto* gpu_executable =
-      new GpuExecutable(ptx, cubin, {cc_major, cc_minor},
-                        std::move(thunk_schedule), std::move(module),
-                        std::move(buffer_assignment), ShapeSizeBytesFunction());
+  std::unique_ptr<HloProfileIndexMap> profile_index_map;
+  std::unique_ptr<HloProfilePrinter> profile_printer;
+
+  if (module->config().hlo_profiling_enabled()) {
+    HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
+    profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
+    profile_printer =
+        CreateHloProfilePrinter(*profile_index_map, cost_analysis);
+  }
+
+  auto* gpu_executable = new GpuExecutable(
+      ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
+      std::move(module), std::move(buffer_assignment),
+      std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);


@@ -113,14 +113,15 @@ GpuExecutable::GpuExecutable(
     std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::unique_ptr<const HloModule> hlo_module,
     std::unique_ptr<const BufferAssignment> assignment,
-    HloCostAnalysis::ShapeSizeFunction shape_size_function)
-    : Executable(std::move(hlo_module)),
+    std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
+    : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+                 std::move(hlo_profile_index_map)),
       ptx_(ptx),
      cubin_(cubin),
      compute_capability_(compute_capability),
      thunk_schedule_(std::move(thunk_schedule)),
-      assignment_(std::move(assignment)),
-      shape_size_function_(std::move(shape_size_function)) {}
+      assignment_(std::move(assignment)) {}

 Status GpuExecutable::ExecuteThunks(
     const ServiceExecutableRunOptions* run_options,
@@ -358,9 +359,5 @@ const PointsToSet& GpuExecutable::GetRootPointsToSet() const {
       module().entry_computation()->root_instruction());
 }

-std::unique_ptr<HloCostAnalysis> GpuExecutable::CreateCostAnalysis() const {
-  return MakeUnique<HloCostAnalysis>(shape_size_function_);
-}
-
 }  // namespace gpu
 }  // namespace xla


@@ -54,7 +54,8 @@ class GpuExecutable : public Executable {
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::unique_ptr<const HloModule> hlo_module,
                 std::unique_ptr<const BufferAssignment> assignment,
-                HloCostAnalysis::ShapeSizeFunction shape_size_function);
+                std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);

   // This should be called after set_ir_module_string.
   const string& ir_module_string() const { return ir_module_string_; }
@@ -95,8 +96,6 @@ class GpuExecutable : public Executable {
     return Unimplemented("Equality test on GPU executable is not implemented.");
   }

-  std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
-
  private:
   // If block_host_until_done is false, execution will not block the host
   // until the kernels have completed. This is used as an optimization for
@@ -140,9 +139,6 @@ class GpuExecutable : public Executable {
   // memory for every output/temp buffers.
   const std::unique_ptr<const BufferAssignment> assignment_;

-  // Function to compute the size of a given Shape, in bytes.
-  const HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
 };


@@ -40,7 +40,7 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
   }
 }

-static HloProfilePrinter CreateOwnedHloProfilePrinter(
+std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
     const HloProfileIndexMap& hlo_profile_index_map,
     const HloCostAnalysis& cost_analysis) {
   using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
@@ -108,15 +108,15 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
     delete[] computation_infos;
   };

-  return HloProfilePrinter(computation_infos,
-                           hlo_profile_index_map.computation_count(), deleter);
+  return MakeUnique<HloProfilePrinter>(
+      computation_infos, hlo_profile_index_map.computation_count(), deleter);
 }

-HloExecutionProfile::HloExecutionProfile(const HloModule& module,
-                                         const HloCostAnalysis& cost_analysis)
-    : hlo_profile_index_map_(module),
-      hlo_profile_printer_(
-          CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)),
+HloExecutionProfile::HloExecutionProfile(
+    const HloProfilePrinter* hlo_profile_printer,
+    const HloProfileIndexMap* hlo_profile_index_map)
+    : hlo_profile_printer_(*hlo_profile_printer),
+      hlo_profile_index_map_(*hlo_profile_index_map),
       profile_counters_(
           /*count*/ hlo_profile_index_map_.total_count(),
           /*value*/ 0) {}
@@ -131,10 +131,4 @@ uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {
   return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)];
 }

-string HloExecutionProfile::ToString(
-    const DeviceDescription& device_description) const {
-  return hlo_profile_printer_.ToString(profile_counters_.data(),
-                                       device_description.clock_rate_ghz());
-}
-
 }  // namespace xla


@@ -77,6 +77,11 @@ class HloProfileIndexMap {
   std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_;
 };

+// Create an instance of `HloProfilePrinter` that owns its memory.
+std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
+    const HloProfileIndexMap& hlo_profile_index_map,
+    const HloCostAnalysis& cost_analysis);
+
 // Describes how much time each HLO operation took.
 //
 // Each HloComputation takes a certain number of cycles. This class helps break
@@ -85,8 +90,8 @@ class HloExecutionProfile {
  public:
   using DeviceDescription = perftools::gputools::DeviceDescription;

-  HloExecutionProfile(const HloModule& module,
-                      const HloCostAnalysis& cost_analysis);
+  HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer,
+                      const HloProfileIndexMap* hlo_profile_index_map);

   // Record how many cycles this HLO took to execute.
   void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken);
@@ -114,15 +119,16 @@ class HloExecutionProfile {
   // for the operations in a given computation. Returns an empty string if it
   // wasn't possible to generate a printable version. cost_analysis should be a
   // clean analysis that can be used to visit the computation.
-  string ToString(const DeviceDescription& device_description) const;
+  string ToString(const DeviceDescription& device_description) const {
+    return hlo_profile_printer_.ToString(profile_counters_.data(),
+                                         device_description.clock_rate_ghz());
+  }
+
+  std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }

  private:
-  // hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to
-  // an index in profile_counters_.
-  HloProfileIndexMap hlo_profile_index_map_;
-
-  // Used to print profile_counters_ in a human readable form.
-  HloProfilePrinter hlo_profile_printer_;
+  const HloProfilePrinter& hlo_profile_printer_;
+  const HloProfileIndexMap& hlo_profile_index_map_;

   // Stores per-Hlo profile counters. This is the only thing that changes when
   // we execute an XLA computation.

View File

@ -72,7 +72,11 @@ TEST_F(HloExecutionProfileTest, Basic) {
}; };
HloCostAnalysis cost_analysis(shape_size_function); HloCostAnalysis cost_analysis(shape_size_function);
HloExecutionProfile execution_profile(*hlo_module, cost_analysis); HloProfileIndexMap profile_index_map(*hlo_module);
std::unique_ptr<HloProfilePrinter> profile_printer =
CreateHloProfilePrinter(profile_index_map, cost_analysis);
HloExecutionProfile execution_profile(profile_printer.get(),
&profile_index_map);
const int64 add_cycles = 1000; const int64 add_cycles = 1000;
const int64 dot_cycles = 4000; const int64 dot_cycles = 4000;

View File

@ -42,7 +42,8 @@ namespace sep = ::perftools::gputools::interpreter;
InterpreterExecutable::InterpreterExecutable( InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<const HloModule> hlo_module) std::unique_ptr<const HloModule> hlo_module)
: Executable(std::move(hlo_module)) {} : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
/*hlo_profile_index_map=*/nullptr) {}
InterpreterExecutable::~InterpreterExecutable() {} InterpreterExecutable::~InterpreterExecutable() {}
@ -156,10 +157,5 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteAsyncOnStream(
return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
} }
std::unique_ptr<HloCostAnalysis> InterpreterExecutable::CreateCostAnalysis()
const {
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
}
} // namespace interpreter } // namespace interpreter
} // namespace xla } // namespace xla

View File

@ -61,8 +61,6 @@ class InterpreterExecutable : public Executable {
static int64 ShapeSizeBytes(const Shape& shape); static int64 ShapeSizeBytes(const Shape& shape);
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
}; };

View File

@ -575,12 +575,13 @@ Service::ExecuteParallelAndRegisterResult(
// profile. // profile.
for (auto& index_to_profiled_stream : index_to_profiled_streams) { for (auto& index_to_profiled_stream : index_to_profiled_streams) {
int64 device = index_to_profiled_stream.first; int64 device = index_to_profiled_stream.first;
auto& module = executables[device]->module();
se::Stream* stream = index_to_profiled_stream.second; se::Stream* stream = index_to_profiled_stream.second;
HloExecutionProfile hlo_profile(module, Executable* executable = executables[device];
*executables[device]->CreateCostAnalysis()); const HloModule& module = executable->module();
TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(),
&hlo_profile, stream->parent())); &executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
XLA_LOG_LINES( XLA_LOG_LINES(
tensorflow::INFO, tensorflow::INFO,
hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription()));

View File

@ -773,6 +773,11 @@ xla_test(
xla_test( xla_test(
name = "bfloat16_test", name = "bfloat16_test",
srcs = ["bfloat16_test.cc"], srcs = ["bfloat16_test.cc"],
blacklisted_backends = [
"cpu",
"cpu_parallel",
"gpu",
],
shard_count = 40, shard_count = 40,
deps = [ deps = [
":test_utils", ":test_utils",
@ -1343,6 +1348,7 @@ xla_test(
srcs = ["client_test.cc"], srcs = ["client_test.cc"],
deps = [ deps = [
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",

View File

@ -51,8 +51,7 @@ class Bfloat16Test : public ClientLibraryTestBase {
const ErrorSpec error_spec_{0.001, 0.001}; const ErrorSpec error_spec_{0.001, 0.001};
}; };
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( XLA_TEST_F(Bfloat16Test, ScalarOperation) {
DISABLED_ON_CPU(ScalarOperation)))) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f)); auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f)); auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
@ -62,8 +61,7 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
error_spec_); error_spec_);
} }
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
DISABLED_ON_CPU(NegateScalarF16)))) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f))); builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
@ -71,5 +69,83 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
error_spec_); error_spec_);
} }
XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
const int kFeatureIndex = 2;
ComputationBuilder builder(client_, TestName());
auto operand = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
auto scale = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)});
auto offset = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
auto tuple = builder.BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
auto expected = *Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.7f)}, {static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.65f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
const int kFeatureIndex = 2;
ComputationBuilder builder(client_, TestName());
auto operand = builder.ConstantR4FromArray4D<bfloat16>(
Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));
auto scale = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
auto mean = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)});
auto var = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
auto grad_output = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
builder.BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
auto expected = *Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace } // namespace
} // namespace xla } // namespace xla
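As a reader-side sanity check (not part of the change), the expected literals in the BatchNormTraining test above can be reproduced with a few lines of NumPy in float32; the bfloat16 constants in the test are just these values rounded to bfloat16 precision:

import numpy as np

# The 2x2x2x1 operand from the test; the feature dimension is index 2.
x = np.array([[[[1.], [2.]], [[3.], [4.]]],
              [[[5.], [6.]], [[7.], [8.]]]], dtype=np.float32)
scale = np.array([2., 3.], dtype=np.float32)
offset = np.array([1., 2.], dtype=np.float32)
epsilon = 0.001

# BatchNormTraining reduces over every dimension except the feature one.
mean = x.mean(axis=(0, 1, 3))        # [4., 5.]
var = x.var(axis=(0, 1, 3))          # [5., 5.]

# Broadcast back along the feature axis, normalize, then scale and shift.
bshape = [1, 1, 2, 1]
y = (x - mean.reshape(bshape)) / np.sqrt(var.reshape(bshape) + epsilon)
y = y * scale.reshape(bshape) + offset.reshape(bshape)
print(mean, var)        # matches the {4, 5} and {5, 5} tuple elements
print(np.round(y, 2))   # matches the normalized R4 literal up to bfloat16 rounding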

View File

@ -29,6 +29,7 @@ def xla_test(name,
deps, deps,
xla_test_library_deps=[], xla_test_library_deps=[],
backends=[], backends=[],
blacklisted_backends=[],
args=[], args=[],
tags=[], tags=[],
copts=[], copts=[],
@ -92,17 +93,24 @@ def xla_test(name,
backends: A list of backends to generate tests for. Supported backends: A list of backends to generate tests for. Supported
values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will
be generated for all supported backends. be generated for all supported backends.
blacklisted_backends: A list of backends to NOT generate tests for.
args: Test arguments for the target. args: Test arguments for the target.
tags: Tags for the target. tags: Tags for the target.
backend_args: A dict mapping backend name to list of additional args to copts: Additional copts to pass to the build.
use for that target. data: Additional data to pass to the build.
backend_tags: A dict mapping backend name to list of additional tags to backend_tags: A dict mapping backend name to list of additional tags to
use for that target. use for that target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
**kwargs: Additional keyword arguments to pass to native.cc_test.
""" """
test_names = [] test_names = []
if not backends: if not backends:
backends = all_backends backends = all_backends
backends = [backend for backend in backends
if backend not in blacklisted_backends]
native.cc_library( native.cc_library(
name="%s_lib" % name, name="%s_lib" % name,
srcs=srcs, srcs=srcs,

View File

@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -42,26 +44,26 @@ TEST_F(ClientTest, ExecuteWithLayout) {
for (const std::vector<int64>& transfer_layout : layouts) { for (const std::vector<int64>& transfer_layout : layouts) {
b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}), b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
b.ConstantR2<int32>({{10, 20}, {30, 40}})); b.ConstantR2<int32>({{10, 20}, {30, 40}}));
auto computation = b.Build(); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
ASSERT_TRUE(computation.ok()) << computation.status();
ExecutionOptions execution_options = execution_options_; ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() = *execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
execute_layout); execute_layout);
std::unique_ptr<GlobalData> data = TF_ASSERT_OK_AND_ASSIGN(
client_->Execute(computation.ValueOrDie(), {}, &execution_options) std::unique_ptr<GlobalData> data,
.ConsumeValueOrDie(); client_->Execute(computation, {}, &execution_options));
std::unique_ptr<Literal> expected_literal = std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>( Literal::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
auto computed = client_->Transfer(*data, &expected_literal->shape()); TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));
LiteralTestUtil::AssertEqualShapesAndLayouts( LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
expected_literal->shape(), computed.ValueOrDie()->shape()); computed->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
} }
} }
} }
@ -72,8 +74,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}), b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
b.ConstantR2<int32>({{10, 20}, {30, 40}})}); b.ConstantR2<int32>({{10, 20}, {30, 40}})});
auto computation = b.Build(); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
ASSERT_TRUE(computation.ok()) << computation.status();
ExecutionOptions execution_options = execution_options_; ExecutionOptions execution_options = execution_options_;
// Create a result shape with one element column major and the other row // Create a result shape with one element column major and the other row
@ -85,10 +86,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})}); /*minor_to_major=*/{1, 0})});
auto result = TF_ASSERT_OK_AND_ASSIGN(
client_ auto result,
->ExecuteAndTransfer(computation.ValueOrDie(), {}, &execution_options) client_->ExecuteAndTransfer(computation, {}, &execution_options));
.ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}}, LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
result->tuple_literals(0)); result->tuple_literals(0));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}}, LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
@ -107,5 +107,42 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0}))); /*minor_to_major=*/{1, 0})));
} }
TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}})));
ComputationBuilder b(client_, TestName() + ".add");
b.Add(b.Parameter(0, shape, "param_0"),
b.ConstantR2<int32>({{1, 2}, {3, 4}}));
TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build());
// We can't really test parallel execution on CPU since all of the cores in a
// CPU are presented as a single device. So for now we test "parallel"
// execution on a single device.
std::vector<Client::ComputationInstance> computation_instances;
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
ASSERT_EQ(devices.size(), 1);
ExecutionOptions options = execution_options_;
*options.add_device_handles() = devices[0];
computation_instances.push_back(Client::ComputationInstance(
add_with_one_arg, {const_arg.get()}, options, nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto results,
client_->ExecuteParallel(computation_instances));
auto expected_result = Literal::CreateR2<int32>({{6, 8}, {10, 12}});
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));
LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -37,7 +37,7 @@ set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
-std=c++11 -fno-rtti -fno-exceptions \ -std=c++11 -fno-rtti -fno-exceptions \
-O2 -Wno-narrowing -fomit-frame-pointer \ -O2 -Wno-narrowing -fomit-frame-pointer \
-mfpu=neon -mfloat-abi=softfp -fPIE \ -mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
-ftemplate-depth=900 \ -ftemplate-depth=900 \
-DGOOGLE_PROTOBUF_NO_RTTI \ -DGOOGLE_PROTOBUF_NO_RTTI \
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER") -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")

View File

@ -16,7 +16,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/device_base.h"

View File

@ -408,7 +408,7 @@ tf_cc_test(
# Learner/stochastic # Learner/stochastic
cc_library( cc_library(
name = "gradient-stats", name = "gradient-stats",
hdrs = ["learner/stochastic/stats/gradient-stats.h"], hdrs = ["learner/common/stats/gradient-stats.h"],
deps = [ deps = [
"//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_headers_lib",
"//third_party/eigen3", "//third_party/eigen3",
@ -417,7 +417,7 @@ cc_library(
cc_library( cc_library(
name = "node-stats", name = "node-stats",
hdrs = ["learner/stochastic/stats/node-stats.h"], hdrs = ["learner/common/stats/node-stats.h"],
deps = [ deps = [
":gradient-stats", ":gradient-stats",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
@ -429,7 +429,7 @@ cc_library(
cc_library( cc_library(
name = "split-stats", name = "split-stats",
hdrs = ["learner/stochastic/stats/split-stats.h"], hdrs = ["learner/common/stats/split-stats.h"],
deps = [ deps = [
":node-stats", ":node-stats",
], ],
@ -437,7 +437,7 @@ cc_library(
cc_library( cc_library(
name = "feature-split-candidate", name = "feature-split-candidate",
hdrs = ["learner/stochastic/stats/feature-split-candidate.h"], hdrs = ["learner/common/stats/feature-split-candidate.h"],
deps = [ deps = [
":split-stats", ":split-stats",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
@ -447,7 +447,7 @@ cc_library(
tf_cc_test( tf_cc_test(
name = "node-stats_test", name = "node-stats_test",
size = "small", size = "small",
srcs = ["learner/stochastic/stats/node-stats_test.cc"], srcs = ["learner/common/stats/node-stats_test.cc"],
deps = [ deps = [
":node-stats", ":node-stats",
"//tensorflow/core:tensor_testutil", "//tensorflow/core:tensor_testutil",

View File

@ -13,10 +13,10 @@
// limitations under the License. // limitations under the License.
// //
// ============================================================================= // =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h" #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
namespace tensorflow { namespace tensorflow {
@ -58,4 +58,4 @@ struct FeatureSplitCandidate {
} // namespace boosted_trees } // namespace boosted_trees
} // namespace tensorflow } // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_

View File

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
#include <math.h> #include <math.h>
@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) {
} // namespace boosted_trees } // namespace boosted_trees
} // namespace tensorflow } // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/Eigenvalues" #include "third_party/eigen3/Eigen/Eigenvalues"
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h" #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
@ -298,4 +298,4 @@ struct NodeStats {
} // namespace boosted_trees } // namespace boosted_trees
} // namespace tensorflow } // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
#include <string> #include <string>
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
namespace tensorflow { namespace tensorflow {
namespace boosted_trees { namespace boosted_trees {
@ -81,4 +81,4 @@ struct SplitStats {
} // namespace boosted_trees } // namespace boosted_trees
} // namespace tensorflow } // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_

View File

@ -32,12 +32,26 @@ from tensorflow.python.platform import test
class CrfTest(test.TestCase): class CrfTest(test.TestCase):
def testCrfSequenceScore(self): def testCrfSequenceScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
transition_params = np.array( transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32) # Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[4, 5, -3]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([1], dtype=np.int32)
]
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
with self.test_session() as sess: with self.test_session() as sess:
sequence_score = crf.crf_sequence_score( sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0), inputs=array_ops.expand_dims(inputs, 0),
@ -89,13 +103,29 @@ class CrfTest(test.TestCase):
self.assertAllClose(tf_binary_score, expected_binary_score) self.assertAllClose(tf_binary_score, expected_binary_score)
def testCrfLogNorm(self): def testCrfLogNorm(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array( transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[3, -1, 3]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([2], dtype=np.int32)
]
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
num_words = inputs.shape[0] num_words = inputs.shape[0]
num_tags = inputs.shape[1] num_tags = inputs.shape[1]
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess: with self.test_session() as sess:
all_sequence_scores = [] all_sequence_scores = []
@ -201,11 +231,27 @@ class CrfTest(test.TestCase):
expected_max_sequence[:sequence_lengths]) expected_max_sequence[:sequence_lengths])
def testCrfDecode(self): def testCrfDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array( transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32) # Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[-1, 2, 1]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([2], dtype=np.int32)
]
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
num_words = inputs.shape[0] num_words = inputs.shape[0]
num_tags = inputs.shape[1] num_tags = inputs.shape[1]

View File

@ -53,7 +53,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
@ -101,6 +103,17 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
Returns: Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores. sequence_scores: A [batch_size] vector of unnormalized sequence scores.
""" """
# If max_seq_len is 1, we skip the score calculation and simply gather the
# unary potentials of the single tag.
def _single_seq_fn():
batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
example_inds = array_ops.reshape(
math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
return array_ops.gather_nd(
array_ops.squeeze(inputs, [1]),
array_ops.concat([example_inds, tag_indices], axis=1))
def _multi_seq_fn():
# Compute the scores of the given tag sequence. # Compute the scores of the given tag sequence.
unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
binary_scores = crf_binary_score(tag_indices, sequence_lengths, binary_scores = crf_binary_score(tag_indices, sequence_lengths,
@ -108,6 +121,12 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
sequence_scores = unary_scores + binary_scores sequence_scores = unary_scores + binary_scores
return sequence_scores return sequence_scores
return utils.smart_cond(
pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
def crf_log_norm(inputs, sequence_lengths, transition_params): def crf_log_norm(inputs, sequence_lengths, transition_params):
"""Computes the normalization for a CRF. """Computes the normalization for a CRF.
@ -124,6 +143,14 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
# algorithm. # algorithm.
first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]) first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
first_input = array_ops.squeeze(first_input, [1]) first_input = array_ops.squeeze(first_input, [1])
# If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
# the "initial state" (the unary potentials).
def _single_seq_fn():
return math_ops.reduce_logsumexp(first_input, [1])
def _multi_seq_fn():
"""Forward computation of alpha values."""
rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
# Compute the alpha values in the forward algorithm in order to get the # Compute the alpha values in the forward algorithm in order to get the
@ -138,6 +165,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
log_norm = math_ops.reduce_logsumexp(alphas, [1]) log_norm = math_ops.reduce_logsumexp(alphas, [1])
return log_norm return log_norm
max_seq_len = array_ops.shape(inputs)[1]
return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
true_fn=_single_seq_fn,
false_fn=_multi_seq_fn)
def crf_log_likelihood(inputs, def crf_log_likelihood(inputs,
tag_indices, tag_indices,
@ -437,10 +469,22 @@ def crf_decode(potentials, transition_params, sequence_length):
sequence_length: A [batch_size] vector of true sequence lengths. sequence_length: A [batch_size] vector of true sequence lengths.
Returns: Returns:
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
Contains the highest scoring tag indices. Contains the highest scoring tag indices.
best_score: A [batch_size] tensor, containing the score of decode_tags. best_score: A [batch_size] vector, containing the score of `decode_tags`.
""" """
# If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
# and the max activation.
def _single_seq_fn():
squeezed_potentials = array_ops.squeeze(potentials, [1])
decode_tags = array_ops.expand_dims(
math_ops.argmax(squeezed_potentials, axis=1), 1)
best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score
def _multi_seq_fn():
"""Decoding of highest scoring sequence."""
# For simplicity, in shape comments, denote: # For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value num_tags = potentials.get_shape()[2].value
@ -450,32 +494,39 @@ def crf_decode(potentials, transition_params, sequence_length):
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn( backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O]
crf_fwd_cell, crf_fwd_cell,
inputs=inputs, inputs=inputs,
sequence_length=sequence_length - 1, sequence_length=sequence_length - 1,
initial_state=initial_state, initial_state=initial_state,
time_major=False, time_major=False,
dtype=dtypes.int32) # [B, T - 1, O], [B, O] dtype=dtypes.int32)
backpointers = gen_array_ops.reverse_sequence( backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O]
backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O] backpointers, sequence_length - 1, seq_dim=1)
# Computes backward decoding. Extract tag indices from backpointers. # Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B]
dtype=dtypes.int32) # [B] dtype=dtypes.int32)
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn( decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1]
crf_bwd_cell, crf_bwd_cell,
inputs=backpointers, inputs=backpointers,
sequence_length=sequence_length - 1, sequence_length=sequence_length - 1,
initial_state=initial_state, initial_state=initial_state,
time_major=False, time_major=False,
dtype=dtypes.int32) # [B, T - 1, 1] dtype=dtypes.int32)
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T] decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T]
decode_tags = gen_array_ops.reverse_sequence( axis=1)
decode_tags, sequence_length, seq_dim=1) # [B, T] decode_tags = gen_array_ops.reverse_sequence( # [B, T]
decode_tags, sequence_length, seq_dim=1)
best_score = math_ops.reduce_max(last_score, axis=1) # [B] best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score return decode_tags, best_score
return utils.smart_cond(
pred=math_ops.equal(
potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
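The three _single_seq_fn branches added above (in crf_sequence_score, crf_log_norm, and crf_decode) all collapse to simple operations on the unary potentials, since a length-1 sequence has no transitions. A minimal NumPy sketch of that reasoning, using made-up potentials rather than the library code itself:

import numpy as np

# Hypothetical unary potentials for one batch element of length 1 with 3 tags.
inputs = np.array([4.0, 5.0, -3.0], dtype=np.float32)    # shape [num_tags]
tag_index = 1

# crf_sequence_score: no transitions, so the score is just the unary
# potential of the chosen tag.
sequence_score = inputs[tag_index]                        # 5.0

# crf_log_norm: the forward algorithm reduces to a logsumexp over the unary
# potentials of the single step.
log_norm = np.log(np.sum(np.exp(inputs)))                 # logsumexp([4, 5, -3])

# crf_decode: Viterbi reduces to an argmax over the same potentials.
decode_tag = int(np.argmax(inputs))                       # 1
best_score = float(np.max(inputs))                        # 5.0

print(sequence_score, log_norm, decode_tag, best_score)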

View File

@ -187,6 +187,7 @@ py_test(
"manual", # b/67958761 "manual", # b/67958761
], ],
deps = [ deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",

View File

@ -723,5 +723,41 @@ class BatchDatasetSerializationTest(
num_outputs) num_outputs)
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def testPaddedBatch(self):
def build_dataset(seq_lens):
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
lambda x: array_ops.fill([x], x)).padded_batch(
4, padded_shapes=[-1])
seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
self.run_core_tests(lambda: build_dataset(seq_lens1),
lambda: build_dataset(seq_lens2), 8)
def testPaddedBatchNonDefaultPadding(self):
def build_dataset(seq_lens):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
padded_shape = [-1]
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
fill_tuple).padded_batch(
4,
padded_shapes=(padded_shape, padded_shape),
padding_values=(-1, "<end>"))
seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
self.run_core_tests(lambda: build_dataset(seq_lens1),
lambda: build_dataset(seq_lens2), 8)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()
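To make the padding behavior exercised by PaddedBatchDatasetSerializationTest above concrete, here is a small NumPy emulation of what padded_batch(4, padded_shapes=[-1]) produces for variable-length rows; the sequence lengths are hypothetical and the code is an illustration, not the dataset implementation:

import numpy as np

# Hypothetical sequence lengths, as produced by the test's build_dataset helper.
seq_lens = np.array([3, 1, 4, 2], dtype=np.int32)
rows = [np.full([n], n, dtype=np.int32) for n in seq_lens]   # fill([x], x)

# padded_batch(4, padded_shapes=[-1]) pads every row in a batch of 4 to the
# longest row in that batch, using 0 as the default padding value.
max_len = max(len(r) for r in rows)
batch = np.stack([np.pad(r, (0, max_len - len(r)), mode='constant',
                         constant_values=0) for r in rows])
print(batch)
# [[3 3 3 0]
#  [1 0 0 0]
#  [4 4 4 4]
#  [2 2 0 0]]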

View File

@ -22,8 +22,10 @@ import math
import threading import threading
import time import time
import numpy as np
from six.moves import zip_longest from six.moves import zip_longest
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -209,6 +211,46 @@ class InterleaveDatasetTest(test.TestCase):
sess.run(get_next) sess.run(get_next)
class InterleaveDatasetSeriazationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_iterator_graph(self, input_values, cycle_length, block_length):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
cycle_length, block_length)
def testSerializationCore(self):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
# cycle_length > 1, block_length > 1
cycle_length = 2
block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
lambda: self._build_iterator_graph(
input_values, cycle_length * 2, block_length * 1),
num_outputs)
# cycle_length = 1
cycle_length = 1
block_length = 3
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
None, num_outputs)
# block_length = 1
cycle_length = 2
block_length = 1
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
None, num_outputs)
# pylint: enable=g-long-lambda
class ParallelInterleaveDatasetTest(test.TestCase): class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self): def setUp(self):

View File

@ -41,6 +41,7 @@ def try_import(name): # pylint: disable=invalid-name
tf_logging.warning("Could not import %s: %s" % (name, str(e))) tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module return module
stats = try_import("scipy.stats") stats = try_import("scipy.stats")
@ -62,9 +63,9 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected, scale_shape.eval()) self.assertAllEqual(expected, scale_shape.eval())
loc = array_ops.zeros(loc_shape) loc = array_ops.zeros(loc_shape)
scale = array_ops.ones(scale_shape) scale = array_ops.ones(scale_shape)
self.assertAllEqual( self.assertAllEqual(expected,
expected, array_ops.shape(
array_ops.shape(cauchy_lib.Cauchy(loc, scale).sample()).eval()) cauchy_lib.Cauchy(loc, scale).sample()).eval())
def _testParamStaticShapes(self, sample_shape, expected): def _testParamStaticShapes(self, sample_shape, expected):
param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape) param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape)
@ -92,8 +93,7 @@ class CauchyTest(test.TestCase):
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
log_pdf = cauchy.log_prob(x) log_pdf = cauchy.log_prob(x)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape)
log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.eval().shape) log_pdf.eval().shape)
self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) self.assertAllEqual(cauchy.batch_shape, log_pdf.shape)
@ -115,16 +115,15 @@ class CauchyTest(test.TestCase):
with self.test_session(): with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([[3.0, -3.0]] * batch_size) loc = constant_op.constant([[3.0, -3.0]] * batch_size)
scale = constant_op.constant([[np.sqrt(10.0), np.sqrt(15.0)]] * scale = constant_op.constant(
batch_size) [[np.sqrt(10.0), np.sqrt(15.0)]] * batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
log_pdf = cauchy.log_prob(x) log_pdf = cauchy.log_prob(x)
log_pdf_values = log_pdf.eval() log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.shape, (6, 2)) self.assertEqual(log_pdf.shape, (6, 2))
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape)
log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.eval().shape) log_pdf.eval().shape)
self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) self.assertAllEqual(cauchy.batch_shape, log_pdf.shape)
@ -248,8 +247,7 @@ class CauchyTest(test.TestCase):
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
entropy = cauchy.entropy() entropy = cauchy.entropy()
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.shape)
entropy.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
entropy.eval().shape) entropy.eval().shape)
self.assertAllEqual(cauchy.batch_shape, entropy.shape) self.assertAllEqual(cauchy.batch_shape, entropy.shape)
@ -257,7 +255,7 @@ class CauchyTest(test.TestCase):
if not stats: if not stats:
return return
expected_entropy = stats.cauchy(loc, scale).entropy() expected_entropy = stats.cauchy(loc, scale[0]).entropy().reshape((1, 3))
self.assertAllClose(expected_entropy, entropy.eval()) self.assertAllClose(expected_entropy, entropy.eval())
def testCauchyMode(self): def testCauchyMode(self):
@ -368,8 +366,8 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape) self.assertAllEqual(expected_shape, sample_values.shape)
expected_shape = (tensor_shape.TensorShape( expected_shape = (
[n.eval()]).concatenate(cauchy.batch_shape)) tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape))
self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape) self.assertAllEqual(expected_shape, sample_values.shape)
@ -385,18 +383,18 @@ class CauchyTest(test.TestCase):
samples = cauchy.sample(n) samples = cauchy.sample(n)
sample_values = samples.eval() sample_values = samples.eval()
self.assertEqual(samples.shape, (100000, batch_size, 2)) self.assertEqual(samples.shape, (100000, batch_size, 2))
self.assertAllClose(np.median(sample_values[:, 0, 0]), self.assertAllClose(
loc_v[0], atol=1e-1) np.median(sample_values[:, 0, 0]), loc_v[0], atol=1e-1)
self.assertAllClose(np.median(sample_values[:, 0, 1]), self.assertAllClose(
loc_v[1], atol=1e-1) np.median(sample_values[:, 0, 1]), loc_v[1], atol=1e-1)
expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval()))
self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape) self.assertAllEqual(expected_shape, sample_values.shape)
expected_shape = (tensor_shape.TensorShape( expected_shape = (
[n.eval()]).concatenate(cauchy.batch_shape)) tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape))
self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape) self.assertAllEqual(expected_shape, sample_values.shape)
@ -428,9 +426,12 @@ class CauchyTest(test.TestCase):
self.assertEqual(cauchy.event_shape, ()) self.assertEqual(cauchy.event_shape, ())
self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) self.assertAllEqual(cauchy.event_shape_tensor().eval(), [])
self.assertAllEqual( self.assertAllEqual(
sess.run(cauchy.batch_shape_tensor(), sess.run(
feed_dict={loc: 5.0, cauchy.batch_shape_tensor(),
scale: [1.0, 2.0]}), [2]) feed_dict={
loc: 5.0,
scale: [1.0, 2.0]
}), [2])
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -30,7 +30,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import distribution
__all__ = [ __all__ = [
"Cauchy", "Cauchy",
] ]
@ -97,7 +96,7 @@ class Cauchy(distribution.Distribution):
validate_args=False, validate_args=False,
allow_nan_stats=True, allow_nan_stats=True,
name="Cauchy"): name="Cauchy"):
"""Construct Cauchy distributions with loc and and scale `loc` and `scale`. """Construct Cauchy distributions.
The parameters `loc` and `scale` must be shaped in a way that supports The parameters `loc` and `scale` must be shaped in a way that supports
broadcasting (e.g. `loc + scale` is a valid operation). broadcasting (e.g. `loc + scale` is a valid operation).
@ -121,8 +120,8 @@ class Cauchy(distribution.Distribution):
""" """
parameters = locals() parameters = locals()
with ops.name_scope(name, values=[loc, scale]): with ops.name_scope(name, values=[loc, scale]):
with ops.control_dependencies([check_ops.assert_positive(scale)] if with ops.control_dependencies([check_ops.assert_positive(scale)]
validate_args else []): if validate_args else []):
self._loc = array_ops.identity(loc, name="loc") self._loc = array_ops.identity(loc, name="loc")
self._scale = array_ops.identity(scale, name="scale") self._scale = array_ops.identity(scale, name="scale")
check_ops.assert_same_float_dtype([self._loc, self._scale]) check_ops.assert_same_float_dtype([self._loc, self._scale])
@ -138,8 +137,8 @@ class Cauchy(distribution.Distribution):
@staticmethod @staticmethod
def _param_shapes(sample_shape): def _param_shapes(sample_shape):
return dict( return dict(
zip(("loc", "scale"), ([ops.convert_to_tensor( zip(("loc", "scale"),
sample_shape, dtype=dtypes.int32)] * 2))) ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))
@property @property
def loc(self): def loc(self):
@ -153,13 +152,10 @@ class Cauchy(distribution.Distribution):
def _batch_shape_tensor(self): def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape( return array_ops.broadcast_dynamic_shape(
array_ops.shape(self.loc), array_ops.shape(self.loc), array_ops.shape(self.scale))
array_ops.shape(self.scale))
def _batch_shape(self): def _batch_shape(self):
return array_ops.broadcast_static_shape( return array_ops.broadcast_static_shape(self.loc.shape, self.scale.shape)
self.loc.shape,
self.scale.shape)
def _event_shape_tensor(self): def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)

View File

@ -116,6 +116,7 @@ py_library(
deps = [ deps = [
":clip_weights", ":clip_weights",
":conditioning_utils", ":conditioning_utils",
":tensor_pool",
":virtual_batchnorm", ":virtual_batchnorm",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],
@ -219,6 +220,37 @@ py_test(
], ],
) )
py_library(
name = "tensor_pool",
srcs = [
"python/features/python/tensor_pool.py",
"python/features/python/tensor_pool_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:util",
],
)
py_test(
name = "tensor_pool_test",
srcs = ["python/features/python/tensor_pool_test.py"],
srcs_version = "PY2AND3",
deps = [
":tensor_pool",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//third_party/py/numpy",
],
)
py_library( py_library(
name = "virtual_batchnorm", name = "virtual_batchnorm",
srcs = [ srcs = [

View File

@ -0,0 +1,35 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor pool stores values from an input tensor and returns a stored one.
See the following papers for more details.
1) `Learning from simulated and unsupervised images through adversarial
training` (https://arxiv.org/abs/1612.07828).
2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks` (https://arxiv.org/abs/1703.10593).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.tensor_pool_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = tensor_pool_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,118 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor pool stores values from an input tensor and returns a stored one.
We use this to keep a history of values created by a generator, so that
a discriminator can randomly be trained on some older samples, not just the
current one. This helps keep the discriminator from getting too far ahead of
the generator, and keeps the system from oscillating if the discriminator
forgets too quickly what past samples from the generator looked like.
See the following papers for more details.
1) `Learning from simulated and unsupervised images through adversarial
training` (https://arxiv.org/abs/1612.07828).
2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks` (https://arxiv.org/abs/1703.10593).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import random_ops
__all__ = [
'tensor_pool',
]
def tensor_pool(input_value,
pool_size,
pooling_probability=0.5,
name='tensor_pool'):
"""Queue storing input values and returning random previously stored ones.
Every time the returned `output_value` is evaluated, `input_value` is
evaluated and its value is either returned directly (with probability
`1 - pooling_probability`) or stored in the pool, in which case a random one of
the samples currently in the pool is popped and returned. As long as the pool
is not fully filled, the `input_value` is always returned directly as well as
stored in the pool. Note that during inference / testing, it may be appropriate
to set `pool_size` = 0 or `pooling_probability` = 0.
Args:
input_value: A `Tensor` from which to read values to be pooled.
pool_size: An integer specifying the maximum size of the pool.
pooling_probability: A float `Tensor` specifying the probability of getting
a value from the pool, as opposed to just the current input.
name: A string prefix for the name scope for all tensorflow ops.
Returns:
A `Tensor` which, with the given probability, is either the `input_value` or a
randomly chosen sample that was previously inserted in the pool.
Raises:
ValueError: If `pool_size` is negative.
"""
pool_size = int(pool_size)
if pool_size < 0:
raise ValueError('`pool_size` is negative.')
elif pool_size == 0:
return input_value
with ops.name_scope('{}_pool_queue'.format(name),
values=[input_value, pooling_probability]):
pool_queue = data_flow_ops.RandomShuffleQueue(
capacity=pool_size,
min_after_dequeue=0,
dtypes=[input_value.dtype],
shapes=None)
# In pseudo code, this does the following:
# if not pool_full:
# enqueue(input_value)
# return input_value
# else
# dequeue_value = dequeue_random_sample()
# enqueue(input_value)
# if rand() < pooling_probability:
# return dequeue_value
# else
# return input_value
def _get_input_value_pooled():
enqueue_op = pool_queue.enqueue(input_value)
with ops.control_dependencies([enqueue_op]):
return array_ops.identity(input_value)
def _get_random_pool_value_and_enqueue_input():
dequeue_value = pool_queue.dequeue()
with ops.control_dependencies([dequeue_value]):
enqueue_op = pool_queue.enqueue(input_value)
with ops.control_dependencies([enqueue_op]):
prob = random_ops.random_uniform(
(), dtype=dtypes.float32) < pooling_probability
return control_flow_ops.cond(prob, lambda: dequeue_value,
lambda: input_value)
output_value = control_flow_ops.cond(
pool_queue.size() < pool_size, _get_input_value_pooled,
_get_random_pool_value_and_enqueue_input)
return output_value
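A minimal usage sketch of the op added above, assuming the import path shown in the earlier file; the generator output here is a random stand-in and the discriminator call is left as a comment:

```python
# Illustrative sketch only: pool generator outputs before the discriminator.
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl

generated_images = tf.random_normal([16, 64, 64, 3])  # stand-in for a generator output
pooled_images = tensor_pool_impl.tensor_pool(
    generated_images, pool_size=50, pooling_probability=0.5)
# The discriminator would now see a mix of current and historical samples:
# disc_logits = discriminator_fn(pooled_images)   # discriminator_fn is hypothetical
```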

View File

@ -0,0 +1,94 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.gan.python.features.tensor_pool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl as tensor_pool
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class TensorPoolTest(test.TestCase):
def test_pool_unknown_input_shape(self):
"""Checks that `input_value` can have unknown shape."""
input_value = array_ops.placeholder(
dtype=dtypes.int32, shape=[None, None, 3])
output_value = tensor_pool.tensor_pool(input_value, pool_size=10)
with self.test_session(use_gpu=True) as session:
for i in range(10):
session.run(output_value, {input_value: [[[i] * 3]]})
session.run(output_value, {input_value: [[[i] * 3] * 2]})
session.run(output_value, {input_value: [[[i] * 3] * 5] * 2})
def test_pool_sequence(self):
"""Checks that values are pooled and returned maximally twice."""
input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
output_value = tensor_pool.tensor_pool(input_value, pool_size=10)
with self.test_session(use_gpu=True) as session:
outs = []
for i in range(50):
out = session.run(output_value, {input_value: i})
outs.append(out)
self.assertLessEqual(out, i)
_, counts = np.unique(outs, return_counts=True)
# Check that each value is returned at most twice.
self.assertTrue((counts <= 2).all())
def test_never_pool(self):
"""Checks that setting `pooling_probability` to zero works."""
input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
output_value = tensor_pool.tensor_pool(
input_value, pool_size=10, pooling_probability=0.0)
with self.test_session(use_gpu=True) as session:
for i in range(50):
out = session.run(output_value, {input_value: i})
self.assertEqual(out, i)
def test_pooling_probability(self):
"""Checks that `pooling_probability` works."""
input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
pool_size = 10
pooling_probability = 0.2
output_value = tensor_pool.tensor_pool(
input_value,
pool_size=pool_size,
pooling_probability=pooling_probability)
with self.test_session(use_gpu=True) as session:
not_pooled = 0
total = 1000
for i in range(total):
out = session.run(output_value, {input_value: i})
if out == i:
not_pooled += 1
self.assertAllClose(
(not_pooled - pool_size) / (total - pool_size),
1 - pooling_probability,
atol=0.03)
if __name__ == '__main__':
test.main()
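Why the estimator in `test_pooling_probability` works: the first `pool_size` inputs are always returned directly while the pool fills, so they are subtracted out; every later input is returned directly with probability `1 - pooling_probability`. A quick numeric sanity check of that expectation (illustrative, not part of the change):

```python
import numpy as np

rng = np.random.RandomState(0)
pool_size, total, p = 10, 1000, 0.2
# After the warm-up phase, each step returns the fresh input with probability 1 - p.
not_pooled = pool_size + np.sum(rng.rand(total - pool_size) >= p)
print((not_pooled - pool_size) / (total - pool_size))  # close to 1 - p = 0.8
```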

View File

@ -40,6 +40,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
cov_ema_decay, cov_ema_decay,
damping, damping,
layer_collection, layer_collection,
var_list=None,
momentum=0., momentum=0.,
momentum_type="regular", momentum_type="regular",
norm_constraint=None, norm_constraint=None,
@ -66,6 +67,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
blocks, kronecker factors, and losses associated with the blocks, kronecker factors, and losses associated with the
graph. The layer_collection cannot be modified after KfacOptimizer's graph. The layer_collection cannot be modified after KfacOptimizer's
initialization. initialization.
var_list: Optional list or tuple of variables to train. Defaults to the
list of variables collected in the graph under the key
`GraphKeys.TRAINABLE_VARIABLES`.
momentum: The momentum value for this optimizer. Only applies when momentum: The momentum value for this optimizer. Only applies when
momentum_type is 'regular' or 'adam'. (Default: 0) momentum_type is 'regular' or 'adam'. (Default: 0)
momentum_type: The type of momentum to use in this optimizer, one of momentum_type: The type of momentum to use in this optimizer, one of
@ -96,8 +100,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
or 'adam'. or 'adam'.
""" """
# We may consider determining the set of variables some other way, but for variables = var_list
# now it's just all the trainable variables. if variables is None:
variables = tf_variables.trainable_variables() variables = tf_variables.trainable_variables()
self._fisher_est = est.FisherEstimator( self._fisher_est = est.FisherEstimator(
@ -123,7 +127,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
raise ValueError("Momentum must be unspecified if using a momentum_type " raise ValueError("Momentum must be unspecified if using a momentum_type "
"other than 'regular' or 'adam'.") "other than 'regular' or 'adam'.")
self._momentum = ops.convert_to_tensor(momentum, name="momentum") self._momentum = momentum
self._momentum_type = momentum_type self._momentum_type = momentum_type
self._norm_constraint = norm_constraint self._norm_constraint = norm_constraint
@ -313,14 +317,17 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._batch_size, dtype=fft_precon_grads[0].dtype) self._batch_size, dtype=fft_precon_grads[0].dtype)
# compute the entries of the 2x2 matrix # compute the entries of the 2x2 matrix
m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size m_11 = (
+ self.damping * _inner_product_list(precon_grads, precon_grads)) _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
self.damping * _inner_product_list(precon_grads, precon_grads))
m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size m_21 = (
+ self.damping * _inner_product_list(prev_updates, precon_grads)) _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
self.damping * _inner_product_list(prev_updates, precon_grads))
m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size m_22 = (
+ self.damping * _inner_product_list(prev_updates, prev_updates)) _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
self.damping * _inner_product_list(prev_updates, prev_updates))
def non_zero_prevupd_case(): def non_zero_prevupd_case():
r"""Computes optimal (alpha, mu) given non-zero previous update. r"""Computes optimal (alpha, mu) given non-zero previous update.
@ -406,8 +413,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
grads = list(grad for (grad, _) in grads_and_vars) grads = list(grad for (grad, _) in grads_and_vars)
variables = list(var for (_, var) in grads_and_vars) variables = list(var for (_, var) in grads_and_vars)
# previous updates are the negative velocities (up to scaling by LR) # previous updates are the negative velocities (up to scaling by LR)
prev_updates = list(-self._zeros_slot(var, "velocity", self._name) prev_updates = list(
for var in variables) -self._zeros_slot(var, "velocity", self._name) for var in variables)
# Compute optimal velocity update parameters according to quadratic model # Compute optimal velocity update parameters according to quadratic model
alpha, mu, _ = self._compute_qmodel_hyperparams( alpha, mu, _ = self._compute_qmodel_hyperparams(
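The new `var_list` argument follows the same convention as the stock TensorFlow optimizers: when it is `None`, the optimizer falls back to the graph's trainable variables. A standalone sketch of that fallback (function and scope names here are illustrative, not taken from the diff):

```python
import tensorflow as tf

def resolve_var_list(var_list=None):
  """Mirrors the fallback added above: default to all trainable variables."""
  if var_list is None:
    return tf.trainable_variables()
  return list(var_list)

# Example: restrict K-FAC updates to variables created under one scope.
generator_vars = tf.get_collection(
    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
variables_to_train = resolve_var_list(generator_vars)
```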

View File

@ -28,7 +28,6 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
# Method used for inverting matrices. # Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky" POSDEF_INV_METHOD = "cholesky"
@ -202,9 +201,18 @@ def posdef_inv_cholesky(tensor, identity, damping):
return linalg_ops.cholesky_solve(chol, identity) return linalg_ops.cholesky_solve(chol, identity)
def posdef_inv_eig(tensor, identity, damping):
"""Computes inverse(tensor + damping * identity) with eigendecomposition."""
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
tensor + damping * identity)
return math_ops.matmul(
eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
posdef_inv_funcs = { posdef_inv_funcs = {
"matrix_inverse": posdef_inv_matrix_inverse, "matrix_inverse": posdef_inv_matrix_inverse,
"cholesky": posdef_inv_cholesky, "cholesky": posdef_inv_cholesky,
"eig": posdef_inv_eig,
} }
@ -261,8 +269,8 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
# generated by the first gradients_impl.gradients call. # generated by the first gradients_impl.gradients call.
us = [array_ops.zeros_like(y) + float("nan") for y in ys] us = [array_ops.zeros_like(y) + float("nan") for y in ys]
dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, dydxs = gradients_impl.gradients(
stop_gradients=stop_gradients) ys, xs, grad_ys=us, stop_gradients=stop_gradients)
# Deal with strange types that gradients_impl.gradients returns but can't # Deal with strange types that gradients_impl.gradients returns but can't
# deal with. # deal with.
@ -278,3 +286,6 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
return dysdx return dysdx
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
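The new `posdef_inv_eig` path relies on the identity that if `tensor + damping * identity = V diag(e) V^T`, then its inverse is `V diag(1/e) V^T`. A small NumPy check of that equivalence (illustrative only):

```python
import numpy as np

rng = np.random.RandomState(0)
m = rng.randn(4, 4)
tensor = m.dot(m.T)                 # symmetric positive semi-definite matrix
damping, identity = 0.1, np.eye(4)

direct = np.linalg.inv(tensor + damping * identity)

# Same computation as posdef_inv_eig: eigendecompose, then scale columns by 1/e.
eigenvalues, eigenvectors = np.linalg.eigh(tensor + damping * identity)
via_eig = (eigenvectors / eigenvalues).dot(eigenvectors.T)

assert np.allclose(direct, via_eig)
```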

View File

@ -309,7 +309,6 @@ def _fused_batch_norm(inputs,
new_shape = [-1, channels, 1, 1] new_shape = [-1, channels, 1, 1]
inputs = array_ops.reshape(inputs, new_shape) inputs = array_ops.reshape(inputs, new_shape)
inputs_shape = inputs.get_shape() inputs_shape = inputs.get_shape()
dtype = inputs.dtype.base_dtype
if data_format == DATA_FORMAT_NHWC: if data_format == DATA_FORMAT_NHWC:
params_shape = inputs_shape[-1:] params_shape = inputs_shape[-1:]
else: else:

View File

@ -1779,7 +1779,8 @@ class BatchNormTest(test.TestCase):
dtype = dtypes.float32 dtype = dtypes.float32
height, width = 3, 3 height, width = 3, 3
with self.test_session(): with self.test_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype) images = np.random.uniform(size=(5, height, width, 3)).astype(
dtype.as_numpy_dtype)
output = _layers.batch_norm(images, fused=fused) output = _layers.batch_norm(images, fused=fused)
expected_name = ('BatchNorm/FusedBatchNorm' if fused else expected_name = ('BatchNorm/FusedBatchNorm' if fused else
'BatchNorm/batchnorm') 'BatchNorm/batchnorm')
@ -2665,18 +2666,18 @@ class BatchNormTest(test.TestCase):
# Test case for 11673 # Test case for 11673
with self.test_session() as sess: with self.test_session() as sess:
a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10)) a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW', _layers.batch_norm(
zero_debias_moving_mean=True) a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10)) a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10))
b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW', _layers.batch_norm(
zero_debias_moving_mean=True) a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True)
sess.run(variables_lib.global_variables_initializer()) sess.run(variables_lib.global_variables_initializer())
def testVariablesAreFloat32(self): def testVariablesAreFloat32(self):
height, width = 3, 3 height, width = 3, 3
with self.test_session(): with self.test_session():
images = random_ops.random_uniform((5, height, width, 3), images = random_ops.random_uniform(
seed=1, dtype=dtypes.float16) (5, height, width, 3), seed=1, dtype=dtypes.float16)
_layers.batch_norm(images, scale=True) _layers.batch_norm(images, scale=True)
beta = variables.get_variables_by_name('beta')[0] beta = variables.get_variables_by_name('beta')[0]
gamma = variables.get_variables_by_name('gamma')[0] gamma = variables.get_variables_by_name('gamma')[0]
@ -2691,17 +2692,13 @@ class BatchNormTest(test.TestCase):
channels = shape[1] channels = shape[1]
images = np.arange(np.product(shape), dtype=dtype).reshape(shape) images = np.arange(np.product(shape), dtype=dtype).reshape(shape)
beta = init_ops.constant_initializer( beta = init_ops.constant_initializer(
np.arange( np.arange(2, channels + 2, dtype=np.float32))
2, channels + 2, dtype=np.float32))
gamma = init_ops.constant_initializer( gamma = init_ops.constant_initializer(
np.arange( np.arange(10, channels + 10, dtype=np.float32) * 2.0)
10, channels + 10, dtype=np.float32) * 2.0)
mean = init_ops.constant_initializer( mean = init_ops.constant_initializer(
np.arange( np.arange(3, channels + 3, dtype=np.float32) * 5.0)
3, channels + 3, dtype=np.float32) * 5.0)
variance = init_ops.constant_initializer( variance = init_ops.constant_initializer(
np.arange( np.arange(1, channels + 1, dtype=np.float32) * 4.0)
1, channels + 1, dtype=np.float32) * 4.0)
output = _layers.batch_norm( output = _layers.batch_norm(
images, images,
fused=True, fused=True,
@ -2726,7 +2723,6 @@ class BatchNormTest(test.TestCase):
res_16 = self._runFusedBatchNorm(shape, np.float16) res_16 = self._runFusedBatchNorm(shape, np.float16)
self.assertAllClose(res_32, res_16, rtol=1e-3) self.assertAllClose(res_32, res_16, rtol=1e-3)
def testAdjustmentCreated(self): def testAdjustmentCreated(self):
# Tests that the adjustment is appropriately passed to and used by the core # Tests that the adjustment is appropriately passed to and used by the core
# BN layer. # BN layer.

View File

@ -28,7 +28,6 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -369,7 +368,8 @@ class DataFeeder(object):
if x_is_dict: if x_is_dict:
num_samples = list(self._x.values())[0].shape[0] num_samples = list(self._x.values())[0].shape[0]
elif tensor_util.is_tensor(self._x): elif tensor_util.is_tensor(self._x):
num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int num_samples = self._x.shape[
0].value # shape will be a Dimension, extract an int
else: else:
num_samples = self._x.shape[0] num_samples = self._x.shape[0]

View File

@ -251,8 +251,9 @@ class SdcaModel(object):
result_dense = 0.0 result_dense = 0.0
for i in range(len(dense_variables)): for i in range(len(dense_variables)):
result_dense += math_ops.matmul( result_dense += math_ops.matmul(dense_features[i],
dense_features[i], array_ops.expand_dims(dense_variables[i], -1)) array_ops.expand_dims(
dense_variables[i], -1))
# Reshaping to allow shape inference at graph construction time. # Reshaping to allow shape inference at graph construction time.
return array_ops.reshape(result_dense, [-1]) + result_sparse return array_ops.reshape(result_dense, [-1]) + result_sparse

View File

@ -164,8 +164,8 @@ def toco_convert(input_data,
toco = _toco_flags_pb2.TocoFlags() toco = _toco_flags_pb2.TocoFlags()
toco.input_format = input_format toco.input_format = input_format
toco.output_format = output_format toco.output_format = output_format
toco.drop_control_dependency = drop_control_dependency
model = _model_flags_pb2.ModelFlags() model = _model_flags_pb2.ModelFlags()
model.drop_control_dependency = drop_control_dependency
toco.inference_type = inference_type toco.inference_type = inference_type
for idx, input_tensor in enumerate(input_tensors): for idx, input_tensor in enumerate(input_tensors):
if input_tensor.dtype == _dtypes.float32: if input_tensor.dtype == _dtypes.float32:

View File

@ -40,6 +40,7 @@ from six import StringIO
# TODO(aselle): Disable GPU for now # TODO(aselle): Disable GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# pylint: disable=g-import-not-at-top
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader # TODO(aselle): switch to TensorFlow's resource_loader
@ -383,7 +384,7 @@ def make_zip_of_tests(zip_path,
report["toco_log"] = "" report["toco_log"] = ""
tf.reset_default_graph() tf.reset_default_graph()
with tf.device('/cpu:0'): with tf.device("/cpu:0"):
try: try:
inputs, outputs = make_graph(param_dict_real) inputs, outputs = make_graph(param_dict_real)
except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,

View File

@ -194,7 +194,6 @@ struct ParsedModelFlags {
Arg<string> input_data_type; Arg<string> input_data_type;
Arg<string> input_data_types; Arg<string> input_data_types;
Arg<bool> variable_batch = Arg<bool>(false); Arg<bool> variable_batch = Arg<bool>(false);
Arg<bool> drop_control_dependency = Arg<bool>(false);
Arg<toco::IntList> input_shape; Arg<toco::IntList> input_shape;
Arg<toco::StringMapList> rnn_states; Arg<toco::StringMapList> rnn_states;
Arg<toco::StringMapList> model_checks; Arg<toco::StringMapList> model_checks;
@ -224,6 +223,7 @@ struct ParsedTocoFlags {
// Deprecated flags // Deprecated flags
Arg<string> input_type; Arg<string> input_type;
Arg<string> input_types; Arg<string> input_types;
Arg<bool> drop_control_dependency = Arg<bool>(false);
}; };
} // namespace toco } // namespace toco

View File

@ -35,8 +35,11 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT; using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32; using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef; using tensorflow::GraphDef;
using tensorflow::TensorProto; using tensorflow::TensorProto;
@ -1500,10 +1503,29 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} }
} }
void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) { void AddPlaceholder(const string& name, ArrayDataType type,
GraphDef* tensorflow_graph) {
auto* placeholder = tensorflow_graph->add_node(); auto* placeholder = tensorflow_graph->add_node();
placeholder->set_op("Placeholder"); placeholder->set_op("Placeholder");
switch (type) {
case ArrayDataType::kBool:
(*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
break;
case ArrayDataType::kFloat:
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
break;
case ArrayDataType::kUint8:
(*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
break;
case ArrayDataType::kInt32:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
break;
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;
default:
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
}
placeholder->set_name(name); placeholder->set_name(name);
} }
@ -1531,7 +1553,9 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
void ExportTensorFlowGraphDefImplementation(const Model& model, void ExportTensorFlowGraphDefImplementation(const Model& model,
GraphDef* tensorflow_graph) { GraphDef* tensorflow_graph) {
for (const auto& input_array : model.flags.input_arrays()) { for (const auto& input_array : model.flags.input_arrays()) {
AddPlaceholder(input_array.name(), tensorflow_graph); AddPlaceholder(input_array.name(),
model.arrays.at(input_array.name())->data_type,
tensorflow_graph);
} }
for (const auto& rnn_state : model.flags.rnn_states()) { for (const auto& rnn_state : model.flags.rnn_states()) {
AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(), AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
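With this change the exported placeholder carries the input array's dtype instead of always `DT_FLOAT`. For reference, a small Python sketch of the equivalent node built with the public GraphDef protos (names are illustrative):

```python
from tensorflow.core.framework import graph_pb2, types_pb2

graph_def = graph_pb2.GraphDef()
placeholder = graph_def.node.add()
placeholder.op = 'Placeholder'
placeholder.name = 'input'
# Previously the exporter hard-coded DT_FLOAT; the type now follows the
# model's input array, e.g. DT_UINT8 for quantized inputs.
placeholder.attr['dtype'].type = types_pb2.DT_UINT8
```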

File diff suppressed because it is too large

View File

@ -23,11 +23,19 @@ limitations under the License.
namespace toco { namespace toco {
std::unique_ptr<Model> ImportTensorFlowGraphDef( struct TensorFlowImportFlags {
const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def); // If true, control dependencies will be dropped immediately
// during the import of the TensorFlow GraphDef.
bool drop_control_dependency = false;
};
std::unique_ptr<Model> ImportTensorFlowGraphDef( std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const string& input_file_contents); const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const tensorflow::GraphDef& graph_def);
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const string& input_file_contents);
} // namespace toco } // namespace toco

View File

@ -112,13 +112,6 @@ bool ParseModelFlagsFromCommandLineFlags(
"exclusive " "exclusive "
"with the 'batch' field: at most one of these two fields can be " "with the 'batch' field: at most one of these two fields can be "
"set."), "set."),
Flag(
"drop_control_dependency",
parsed_flags.drop_control_dependency.bind(),
parsed_flags.drop_control_dependency.default_value(),
"If true, ignore control dependency requirements in input TensorFlow "
"GraphDef. Otherwise an error will be raised upon control dependency "
"inputs."),
Flag("rnn_states", parsed_flags.rnn_states.bind(), Flag("rnn_states", parsed_flags.rnn_states.bind(),
parsed_flags.rnn_states.default_value(), ""), parsed_flags.rnn_states.default_value(), ""),
Flag("model_checks", parsed_flags.model_checks.bind(), Flag("model_checks", parsed_flags.model_checks.bind(),
@ -316,7 +309,6 @@ void ReadModelFlagsFromCommandLineFlags(
} while (false) } while (false)
READ_MODEL_FLAG(variable_batch); READ_MODEL_FLAG(variable_batch);
READ_MODEL_FLAG(drop_control_dependency);
#undef READ_MODEL_FLAG #undef READ_MODEL_FLAG

View File

@ -138,8 +138,4 @@ message ModelFlags {
optional int32 count_max = 3 [default = -1]; optional int32 count_max = 3 [default = -1];
} }
repeated ModelCheck model_checks = 14; repeated ModelCheck model_checks = 14;
// If true, ignore control dependency requirements in input TensorFlow
// GraphDef. Otherwise an error will be raised upon control dependency inputs.
optional bool drop_control_dependency = 15;
} }

View File

@ -103,6 +103,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.allow_custom_ops.default_value(), parsed_flags.allow_custom_ops.default_value(),
"If true, allow TOCO to create TF Lite Custom operators for all the" "If true, allow TOCO to create TF Lite Custom operators for all the"
"unsupported Tensorflow ops."), "unsupported Tensorflow ops."),
Flag(
"drop_control_dependency",
parsed_flags.drop_control_dependency.bind(),
parsed_flags.drop_control_dependency.default_value(),
"If true, ignore control dependency requirements in input TensorFlow "
"GraphDef. Otherwise an error will be raised upon control dependency "
"inputs."),
}; };
bool asked_for_help = bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@ -163,6 +170,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone); READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone); READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone); READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
// Deprecated flag handling. // Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) { if (parsed_toco_flags.input_type.specified()) {

View File

@ -36,7 +36,7 @@ enum FileFormat {
// are not normally encoded in model files and in general may not be thought // are not normally encoded in model files and in general may not be thought
// of as properties of models, instead describing how models are to be // of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job. // processed in the context of the present tooling job.
// Next Id: 12 // Next Id: 13
message TocoFlags { message TocoFlags {
// Input file format // Input file format
optional FileFormat input_format = 1; optional FileFormat input_format = 1;
@ -128,4 +128,12 @@ message TocoFlags {
// If true, allow TOCO to create TF Lite Custom operators for all the // If true, allow TOCO to create TF Lite Custom operators for all the
// unsupported Tensorflow ops. // unsupported Tensorflow ops.
optional bool allow_custom_ops = 10; optional bool allow_custom_ops = 10;
// Applies only to the case when the input format is TENSORFLOW_GRAPHDEF.
// If true, then control dependencies will be immediately dropped during
// import.
// If not set, the default behavior is as follows:
// - Default to false if the output format is TENSORFLOW_GRAPHDEF.
// - Default to true in all other cases.
optional bool drop_control_dependency = 12;
} }
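The unset case described above is resolved during import (see the TENSORFLOW_GRAPHDEF branch further down): if the flag is not explicitly set, control dependencies are dropped unless the output format is itself TENSORFLOW_GRAPHDEF. A standalone sketch of that default rule (names are illustrative):

```python
def resolve_drop_control_dependency(flag_value, output_format):
    """flag_value is None when --drop_control_dependency was not specified."""
    if flag_value is not None:
        return flag_value
    # Default: keep control dependencies only for GraphDef-to-GraphDef runs.
    return output_format != 'TENSORFLOW_GRAPHDEF'

assert resolve_drop_control_dependency(None, 'TFLITE') is True
assert resolve_drop_control_dependency(None, 'TENSORFLOW_GRAPHDEF') is False
assert resolve_drop_control_dependency(False, 'TFLITE') is False
```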

View File

@ -85,38 +85,57 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new MakeInitialDequantizeOperator);
} }
void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) { bool SupportsQuantization(FileFormat format) {
const bool output_supports_only_float = return (format == GRAPHVIZ_DOT || format == TFLITE);
toco_flags.output_format() == TENSORFLOW_GRAPHDEF; ;
}
ArrayDataType specified_final_data_type = ArrayDataType::kNone; bool SupportsFusedActivationFunction(FileFormat format) {
return (format == GRAPHVIZ_DOT || format == TFLITE);
}
bool SupportsLstmCell(FileFormat format) {
return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT);
}
bool SupportsPreallocatedWorkspace(FileFormat format) {
return (format == GRAPHVIZ_DOT || format == TFLITE);
}
bool IsRealValued(toco::ArrayDataType type) {
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
type == toco::ArrayDataType::kUint8);
}
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
ArrayDataType type;
if (toco_flags.has_inference_input_type()) { if (toco_flags.has_inference_input_type()) {
specified_final_data_type = type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
} else if (toco_flags.has_inference_type()) { } else if (toco_flags.has_inference_type()) {
specified_final_data_type = type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
ConvertIODataTypeToArrayDataType(toco_flags.inference_type()); } else if (!SupportsQuantization(output_format)) {
} // Data type is implicitly float for non-quantized formats
ArrayDataType final_data_type = ArrayDataType::kNone; type = ArrayDataType::kFloat;
if (output_supports_only_float) {
QCHECK(specified_final_data_type == ArrayDataType::kNone ||
specified_final_data_type == ArrayDataType::kFloat);
final_data_type = ArrayDataType::kFloat;
} else { } else {
final_data_type = specified_final_data_type; // Nothing to do. Data types stay as-is.
return;
} }
for (int i = 0; i < model->flags.input_arrays_size(); i++) { for (int i = 0; i < model->flags.input_arrays_size(); i++) {
auto* array = model->arrays[model->flags.input_arrays(i).name()].get(); string const& array_name = model->flags.input_arrays(i).name();
auto* array = model->arrays[array_name].get();
// Note that the notion of changing data types only applies to real-numbers // Note that the notion of changing data types only applies to real-numbers
// arrays (see the documentation for inference_input_type). // arrays (see the documentation for inference_input_type).
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized, // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
// i.e. represent real numbers by means of quantization parameters, // i.e. represent real numbers by means of quantization parameters,
// and not plain integer uint8 input arrays. // and not plain integer uint8 input arrays.
const bool is_real_numbers = array->data_type == ArrayDataType::kFloat || if (!IsRealValued(array->data_type)) {
array->data_type == ArrayDataType::kUint8; // Ignore non-real data types.
if (is_real_numbers) { continue;
array->final_data_type = final_data_type;
} }
array->final_data_type = type;
} }
} }
@ -127,9 +146,16 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
const string& input_file_contents) { const string& input_file_contents) {
std::unique_ptr<Model> model; std::unique_ptr<Model> model;
switch (toco_flags.input_format()) { switch (toco_flags.input_format()) {
case TENSORFLOW_GRAPHDEF: case TENSORFLOW_GRAPHDEF: {
model = ImportTensorFlowGraphDef(model_flags, input_file_contents); TensorFlowImportFlags tf_import_flags;
tf_import_flags.drop_control_dependency =
toco_flags.has_drop_control_dependency()
? toco_flags.drop_control_dependency()
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
break; break;
}
case TFLITE: case TFLITE:
model = toco::tflite::Import(model_flags, input_file_contents); model = toco::tflite::Import(model_flags, input_file_contents);
ResolveModelFlags(model_flags, model.get()); ResolveModelFlags(model_flags, model.get());
@ -148,23 +174,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format(); const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type(); const IODataType inference_type = toco_flags.inference_type();
const bool output_is_tflite = output_format == TFLITE; const bool quantize_output =
SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;
const bool output_is_tflite_quantized = if (quantize_output) {
output_is_tflite && inference_type == QUANTIZED_UINT8;
if (output_is_tflite_quantized) {
QCHECK_NE(toco_flags.inference_input_type(), FLOAT) QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
<< "Quantized inference is not allowed with float inputs."; << "Quantized inference is not allowed with float inputs.";
} }
SetArrayFinalDataTypes(toco_flags, model); SetFinalDataTypeOnInputs(toco_flags, model);
GraphTransformationsSet transformations; GraphTransformationsSet transformations;
MakeGeneralGraphTransformationsSet(&transformations); MakeGeneralGraphTransformationsSet(&transformations);
auto* remove_trivial_reshape = new RemoveTrivialReshape; auto* remove_trivial_reshape = new RemoveTrivialReshape;
transformations.Add(remove_trivial_reshape); transformations.Add(remove_trivial_reshape);
if (output_format == TFLITE) { if (SupportsFusedActivationFunction(output_format)) {
transformations.Add(new FuseActivationFunctions); transformations.Add(new FuseActivationFunctions);
} else { } else {
transformations.Add(new UnfuseActivationFunctions); transformations.Add(new UnfuseActivationFunctions);
@ -183,25 +207,24 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
// easy to pass a new toco flag. Once that is resolved on the DarwiNN // easy to pass a new toco flag. Once that is resolved on the DarwiNN
// tests side, the special-casing of DarwiNN here can go away. // tests side, the special-casing of DarwiNN here can go away.
// TODO(benoitjacob): so drop it when we can. // TODO(benoitjacob): so drop it when we can.
if ((output_is_tflite_quantized && if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
toco_flags.reorder_across_fake_quant())) {
transformations.Add(new DropFakeQuant); transformations.Add(new DropFakeQuant);
} }
} }
transformations.Add(new ConvertPureConvToDepthwise); transformations.Add(new ConvertPureConvToDepthwise);
// TFLite export does not yet support fused LSTM cell. // TFLite export does not yet support fused LSTM cell.
if (output_format == TENSORFLOW_GRAPHDEF) { if (SupportsLstmCell(output_format)) {
transformations.Add(new IdentifyLstmCell); transformations.Add(new IdentifyLstmCell);
} }
transformations.Add(new ResolveConstantConcatenation); transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations", RunGraphTransformations(model, "general graph transformations",
transformations); transformations);
if (output_is_tflite_quantized) { if (quantize_output) {
RunGraphTransformations(model, "pre-quantization graph transformations", RunGraphTransformations(model, "pre-quantization graph transformations",
{new HardcodeMinMax, new DropFakeQuant}); {new HardcodeMinMax, new DropFakeQuant});
} }
if (output_is_tflite_quantized) { if (quantize_output) {
if (toco_flags.has_default_ranges_min() && if (toco_flags.has_default_ranges_min() &&
toco_flags.has_default_ranges_max()) { toco_flags.has_default_ranges_max()) {
UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(), UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
@ -232,7 +255,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
CheckUnsupportedOperations(*model); CheckUnsupportedOperations(*model);
} }
if (output_is_tflite) { if (SupportsPreallocatedWorkspace(output_format)) {
AllocateTransientArrays(model, kDefaultTransientDataAlignment); AllocateTransientArrays(model, kDefaultTransientDataAlignment);
LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model); LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
} }
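The per-format capability helpers introduced above replace scattered `output_format == TFLITE` checks. A condensed Python sketch of the input final-type decision they feed into (illustrative only, mirroring SetFinalDataTypeOnInputs):

```python
def final_input_type(inference_input_type, inference_type, output_format):
    """Priority order used when choosing the final data type of input arrays."""
    supports_quantization = output_format in ('GRAPHVIZ_DOT', 'TFLITE')
    if inference_input_type is not None:
        return inference_input_type
    if inference_type is not None:
        return inference_type
    if not supports_quantization:
        return 'kFloat'   # non-quantized formats are implicitly float
    return None           # otherwise leave the array types as-is

assert final_input_type(None, None, 'TENSORFLOW_GRAPHDEF') == 'kFloat'
assert final_input_type(None, 'kUint8', 'TFLITE') == 'kUint8'
```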

View File

@ -294,6 +294,7 @@ void LogArray(int log_level, const Model& model, const string& name) {
VLOG(log_level) << "Array: " << name; VLOG(log_level) << "Array: " << name;
switch (array.data_type) { switch (array.data_type) {
case ArrayDataType::kNone: case ArrayDataType::kNone:
VLOG(log_level) << " Data type:";
break; break;
case ArrayDataType::kFloat: case ArrayDataType::kFloat:
VLOG(log_level) << " Data type: kFloat"; VLOG(log_level) << " Data type: kFloat";
@ -309,6 +310,24 @@ void LogArray(int log_level, const Model& model, const string& name) {
<< static_cast<int>(array.data_type) << ")"; << static_cast<int>(array.data_type) << ")";
break; break;
} }
switch (array.final_data_type) {
case ArrayDataType::kNone:
VLOG(log_level) << " Final type:";
break;
case ArrayDataType::kFloat:
VLOG(log_level) << " Final type: kFloat";
break;
case ArrayDataType::kInt32:
VLOG(log_level) << " Final type: kInt32";
break;
case ArrayDataType::kUint8:
VLOG(log_level) << " Final type: kUint8";
break;
default:
VLOG(log_level) << " Final type: other (numerical value: "
<< static_cast<int>(array.final_data_type) << ")";
break;
}
if (array.buffer) { if (array.buffer) {
VLOG(log_level) << " Constant Buffer"; VLOG(log_level) << " Constant Buffer";
} }
@ -1016,7 +1035,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
} }
RESOLVE_MODEL_FLAG(variable_batch) RESOLVE_MODEL_FLAG(variable_batch)
RESOLVE_MODEL_FLAG(drop_control_dependency)
#undef RESOLVE_MODEL_FLAG #undef RESOLVE_MODEL_FLAG
@ -1044,12 +1062,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
"--output_arrays flag must be given on the command-line."; "--output_arrays flag must be given on the command-line.";
for (const auto& input_array_proto : model->flags.input_arrays()) { for (const auto& input_array_proto : model->flags.input_arrays()) {
QCHECK(!input_array_proto.shape().empty())
<< "This model does not have shape defined for input array "
<< input_array_proto.name()
<< ", so one must be specified by a non-empty --input_shape "
"command-line flag.";
auto& input_array = model->GetOrCreateArray(input_array_proto.name()); auto& input_array = model->GetOrCreateArray(input_array_proto.name());
if (input_array_proto.has_data_type()) { if (input_array_proto.has_data_type()) {
const ArrayDataType specified_type = const ArrayDataType specified_type =
@ -1072,6 +1084,14 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
input_array.data_type = ArrayDataType::kFloat; input_array.data_type = ArrayDataType::kFloat;
} }
if (!input_array.has_shape()) {
QCHECK(!input_array_proto.shape().empty())
<< "This model does not have shape defined for input array "
<< input_array_proto.name()
<< ", so one must be specified by a non-empty --input_shape "
"command-line flag.";
}
// Compare/merge the model->flags describing the input_shape with // Compare/merge the model->flags describing the input_shape with
// the actual input array's shape. // the actual input array's shape.
auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
@ -1563,7 +1583,11 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
for (const auto& array_entry : model.arrays) { for (const auto& array_entry : model.arrays) {
const auto& array = *array_entry.second; const auto& array = *array_entry.second;
if (array.final_data_type != ArrayDataType::kNone) { if (array.final_data_type != ArrayDataType::kNone) {
CHECK(array.final_data_type == array.data_type); CHECK(array.final_data_type == array.data_type)
<< "Array \"" << array_entry.first
<< "\" has mis-matching actual and final data types ("
<< static_cast<int>(array.data_type) << ","
<< static_cast<int>(array.final_data_type) << ").";
} }
} }
} }

View File

@ -26,7 +26,6 @@ from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
@ -35,12 +34,18 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [ _allowed_symbols = [
'PowerSignOptimizer', 'AddSignOptimizer' 'PowerSignOptimizer',
'AddSignOptimizer',
'DelayCompensatedGradientDescentOptimizer', 'DelayCompensatedGradientDescentOptimizer',
'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'DropStaleGradientOptimizer',
'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', 'ExternalOptimizerInterface',
'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'LazyAdamOptimizer',
'MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm', 'NadamOptimizer',
'MovingAverageOptimizer',
'ScipyOptimizerInterface',
'VariableClippingOptimizer',
'MultitaskOptimizerWrapper',
'clip_gradients_by_global_norm',
] ]
remove_undocumented(__name__, _allowed_symbols) remove_undocumented(__name__, _allowed_symbols)

View File

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""An optimizer wrapper for stateful optimizers with multitask loss."""
"""An optimizer wrapper that ensures correct behaviour
of stateful optimizers with multitask loss."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -30,26 +28,27 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer from tensorflow.python.training import optimizer
__all__ = ["MultitaskOptimizerWrapper", __all__ = ['MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm']
"clip_gradients_by_global_norm"]
def _is_all_zeros(grad): def _is_all_zeros(grad):
all_zeros = math_ops.equal(math_ops.count_nonzero(grad), 0) all_zeros = math_ops.equal(math_ops.count_nonzero(grad), 0)
return all_zeros return all_zeros
def _get_wrapper(fn, opt): def _get_wrapper(fn, opt):
def wrapper(self, grad, *args, **kwargs): # pylint: disable=unused-argument def wrapper(self, grad, *args, **kwargs): # pylint: disable=unused-argument
all_zeros = _is_all_zeros(grad) all_zeros = _is_all_zeros(grad)
return control_flow_ops.cond( return control_flow_ops.cond(all_zeros, control_flow_ops.no_op,
all_zeros,
control_flow_ops.no_op,
lambda: fn(grad, *args, **kwargs)) lambda: fn(grad, *args, **kwargs))
wrapper = types.MethodType(wrapper, opt) wrapper = types.MethodType(wrapper, opt)
return wrapper return wrapper
class MultitaskOptimizerWrapper(object): class MultitaskOptimizerWrapper(object):
"""Optimizer wrapper that ensures that """Optimizer wrapper making all-zero gradients harmless.
all-zero gradients don't affect the optimizer state.
This might be useful when a multi-task loss is used, This might be useful when a multi-task loss is used,
and some components of the loss might be and some components of the loss might be
@ -88,20 +87,20 @@ class MultitaskOptimizerWrapper(object):
gradvars_clipped, global_step=batch) gradvars_clipped, global_step=batch)
``` ```
""" """
def __init__(self, opt): def __init__(self, opt):
""" """Constructor.
Args: Args:
opt: an instance of a class that implements tf.train.Optimizer. opt: an instance of a class that implements tf.train.Optimizer.
""" """
if not isinstance(opt, optimizer.Optimizer): if not isinstance(opt, optimizer.Optimizer):
raise TypeError( raise TypeError(
"Supplied optimizer must be an instance of tf.train.Optimizer") 'Supplied optimizer must be an instance of tf.train.Optimizer')
self._opt = opt self._opt = opt
overriden_methods = ('_apply_dense', overridden_methods = ('_apply_dense', '_resource_apply_dense',
'_resource_apply_dense', '_apply_sparse', '_resource_apply_sparse')
'_apply_sparse', for name in overridden_methods:
'_resource_apply_sparse')
for name in overriden_methods:
fn = getattr(self._opt, name) fn = getattr(self._opt, name)
wrapper = _get_wrapper(fn, self._opt) wrapper = _get_wrapper(fn, self._opt)
setattr(self._opt, name, wrapper) setattr(self._opt, name, wrapper)
@ -112,6 +111,7 @@ class MultitaskOptimizerWrapper(object):
def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.): def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
"""Clips gradients of a multitask loss by their global norm. """Clips gradients of a multitask loss by their global norm.
Ignores all-zero tensors when computing the global norm. Ignores all-zero tensors when computing the global norm.
Args: Args:
@ -123,16 +123,18 @@ def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
fixed_global_norm: A 0-D (scalar) Tensor representing the global norm. fixed_global_norm: A 0-D (scalar) Tensor representing the global norm.
""" """
gradients, variables = six.moves.zip(*gradients_variables) gradients, variables = six.moves.zip(*gradients_variables)
def _replace_nonexisting_grad(grad): def _replace_nonexisting_grad(grad):
if grad is None: if grad is None:
return grad return grad
all_zeros = _is_all_zeros(grad) all_zeros = _is_all_zeros(grad)
return control_flow_ops.cond(all_zeros, return control_flow_ops.cond(
lambda: array_ops.zeros( all_zeros,
[], dtype=dtypes.as_dtype(grad.dtype)), lambda: array_ops.zeros([], dtype=dtypes.as_dtype(grad.dtype)),
lambda: grad) lambda: grad)
nonzero_gradients = [_replace_nonexisting_grad(g) for g in gradients] nonzero_gradients = [_replace_nonexisting_grad(g) for g in gradients]
fixed_global_norm = clip_ops.global_norm(nonzero_gradients) fixed_global_norm = clip_ops.global_norm(nonzero_gradients)
gradients, _ = clip_ops.clip_by_global_norm(gradients, clip_norm, gradients, _ = clip_ops.clip_by_global_norm(
use_norm=fixed_global_norm) gradients, clip_norm, use_norm=fixed_global_norm)
return list(six.moves.zip(gradients, variables)), fixed_global_norm return list(six.moves.zip(gradients, variables)), fixed_global_norm
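A hedged usage sketch combining the wrapper with the clipping helper, along the lines of the class docstring's example; the optimizer choice, variable, and clip norm are illustrative:

```python
import tensorflow as tf
from tensorflow.contrib.opt.python.training import multitask_optimizer_wrapper as mt

var = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(var))

opt = mt.MultitaskOptimizerWrapper(tf.train.MomentumOptimizer(0.1, momentum=0.9))
grads = tf.gradients(loss, [var])
clipped_gradvars, global_norm = mt.clip_gradients_by_global_norm(
    list(zip(grads, [var])), clip_norm=5.0)
train_op = opt.apply_gradients(clipped_gradvars)
```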

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib.opt.python.training import multitask_optimizer_wrapper from tensorflow.contrib.opt.python.training import multitask_optimizer_wrapper
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -25,13 +28,11 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import momentum from tensorflow.python.training import momentum
import numpy as np
import six
class MultitaskOptimizerWrapperTest(test.TestCase): class MultitaskOptimizerWrapperTest(test.TestCase):
"""Tests for the multitask optimizer wrapper.
""" """
Tests for the multitask optimizer wrapper.
"""
def testWrapper(self): def testWrapper(self):
with self.test_session(): with self.test_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
@ -39,12 +40,10 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32) grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtypes.float32) grads1 = constant_op.constant([0.01, 0.01], dtype=dtypes.float32)
grads_allzero = constant_op.constant([0.0, 0.0], dtype=dtypes.float32) grads_allzero = constant_op.constant([0.0, 0.0], dtype=dtypes.float32)
mom_opt_impl = momentum.MomentumOptimizer( mom_opt_impl = momentum.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
learning_rate=2.0, momentum=0.9)
mom_opt = multitask_optimizer_wrapper.MultitaskOptimizerWrapper( mom_opt = multitask_optimizer_wrapper.MultitaskOptimizerWrapper(
mom_opt_impl) mom_opt_impl)
mom_update = mom_opt.apply_gradients( mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
zip([grads0, grads1], [var0, var1]))
mom_update_partial = mom_opt.apply_gradients( mom_update_partial = mom_opt.apply_gradients(
zip([grads_allzero, grads1], [var0, var1])) zip([grads_allzero, grads1], [var0, var1]))
mom_update_no_action = mom_opt.apply_gradients( mom_update_no_action = mom_opt.apply_gradients(
@ -63,14 +62,13 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
# Step 1: normal momentum update. # Step 1: normal momentum update.
self.evaluate(mom_update) self.evaluate(mom_update)
# Check that the momentum accumulators have been updated. # Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), self.assertAllCloseAccordingToType(
self.evaluate(slot0)) np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), self.assertAllCloseAccordingToType(
self.evaluate(slot1)) np.array([0.01, 0.01]), self.evaluate(slot1))
# Check that the parameters have been updated. # Check that the parameters have been updated.
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), self.evaluate(var0))
self.evaluate(var0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
self.evaluate(var1)) self.evaluate(var1))
@ -78,8 +76,8 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
# Step 2: momentum update that changes only slot1 but not slot0. # Step 2: momentum update that changes only slot1 but not slot0.
self.evaluate(mom_update_partial) self.evaluate(mom_update_partial)
# Check that only the relevant momentum accumulator has been updated. # Check that only the relevant momentum accumulator has been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), self.assertAllCloseAccordingToType(
self.evaluate(slot0)) np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1)) self.evaluate(slot1))
@ -87,8 +85,8 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
# Step 3: momentum update that does not change anything. # Step 3: momentum update that does not change anything.
self.evaluate(mom_update_no_action) self.evaluate(mom_update_no_action)
# Check that the momentum accumulators have *NOT* been updated. # Check that the momentum accumulators have *NOT* been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), self.assertAllCloseAccordingToType(
self.evaluate(slot0)) np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1)) self.evaluate(slot1))
@ -105,8 +103,9 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
grads3 = None grads3 = None
varlist = [var0, var1, var2, var3] varlist = [var0, var1, var2, var3]
gradients = [grads0, grads1, grads2, grads3] gradients = [grads0, grads1, grads2, grads3]
clipped_gradvars, global_norm = multitask_optimizer_wrapper.clip_gradients_by_global_norm( clipped_gradvars, global_norm = (
six.moves.zip(gradients, varlist), clip_norm=1.0) multitask_optimizer_wrapper.clip_gradients_by_global_norm(
six.moves.zip(gradients, varlist), clip_norm=1.0))
clipped_grads = list(six.moves.zip(*clipped_gradvars))[0] clipped_grads = list(six.moves.zip(*clipped_gradvars))[0]
reference_global_norm = np.sqrt(np.sum(np.square([10.0, 15.0, 0.0, 5.0]))) reference_global_norm = np.sqrt(np.sum(np.square([10.0, 15.0, 0.0, 5.0])))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
@ -115,5 +114,6 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
self.evaluate(clipped_grads[2]), np.array([0., 0.])) self.evaluate(clipped_grads[2]), np.array([0., 0.]))
self.assertEqual(clipped_grads[3], None) self.assertEqual(clipped_grads[3], None)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -382,7 +383,8 @@ class RNNCellTest(test.TestCase):
norm_shift=0.0) norm_shift=0.0)
g, out_m = cell(x, state) g, out_m = cell(x, state)
sess.run([variables_lib.global_variables_initializer()]) sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, out_m], { res = sess.run(
[g, out_m], {
x.name: np.ones((batch_size, input_size)), x.name: np.ones((batch_size, input_size)),
c.name: 0.1 * np.ones((batch_size, num_units)), c.name: 0.1 * np.ones((batch_size, num_units)),
h.name: 0.1 * np.ones((batch_size, num_proj)) h.name: 0.1 * np.ones((batch_size, num_proj))

View File

@ -996,24 +996,17 @@ class RNNCellTest(test.TestCase):
output, state = cell(x, hidden) output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], { res = sess.run(
[output, state], {
hidden[0].name: hidden[0].name:
np.array([[[[[1.],[1.]], np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
[[1.],[1.]]], 1.
[[[1.],[1.]], ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]],
[[1.],[1.]]]], [[[2.], [2.]], [[2.], [2.]]]]]),
[[[[2.],[2.]],
[[2.],[2.]]],
[[[2.],[2.]],
[[2.],[2.]]]]]),
x.name: x.name:
np.array([[[[[1.],[1.]], np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
[[1.],[1.]]], 1.
[[[1.],[1.]], ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], [[[2.], [2.]],
[[1.],[1.]]]],
[[[[2.],[2.]],
[[2.],[2.]]],
[[[2.],[2.]],
[[2.], [2.]]]]]) [[2.], [2.]]]]])
}) })
# This is a smoke test, making sure expected values are unchanged. # This is a smoke test, making sure expected values are unchanged.
@ -1276,10 +1269,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[2].c, expected_c1, 1e-5) self.assertAllClose(res[2].c, expected_c1, 1e-5)
self.assertAllClose(res[2].h, expected_h1, 1e-5) self.assertAllClose(res[2].h, expected_h1, 1e-5)
def testBasicLSTMCellWithStateTupleLayerNorm(self): def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell """The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
should be same. """
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope( with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)): "root", initializer=init_ops.constant_initializer(0.5)):
@ -1290,15 +1281,15 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
c1 = array_ops.zeros([1, 2]) c1 = array_ops.zeros([1, 2])
h1 = array_ops.zeros([1, 2]) h1 = array_ops.zeros([1, 2])
state1 = rnn_cell_impl.LSTMStateTuple(c1, h1) state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
cell = rnn_cell_impl.MultiRNNCell( cell = rnn_cell_impl.MultiRNNCell([
[contrib_rnn_cell.LayerNormLSTMCell( contrib_rnn_cell.LayerNormLSTMCell(
2, 2, layer_norm=True, norm_gain=1.0, norm_shift=0.0)
layer_norm=True, for _ in range(2)
norm_gain=1.0, ])
norm_shift=0.0) for _ in range(2)])
h, (s0, s1) = cell(x, (state0, state1)) h, (s0, s1) = cell(x, (state0, state1))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res = sess.run([h, s0, s1], { res = sess.run(
[h, s0, s1], {
x.name: np.array([[1., 1.]]), x.name: np.array([[1., 1.]]),
c0.name: 0.1 * np.asarray([[0, 1]]), c0.name: 0.1 * np.asarray([[0, 1]]),
h0.name: 0.1 * np.asarray([[2, 3]]), h0.name: 0.1 * np.asarray([[2, 3]]),
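
For orientation, the normalization these tests exercise can be sketched in a few lines of numpy; this is an illustrative approximation only (the function name, epsilon, and defaults are assumptions of the sketch, not the contrib implementation), with norm_gain and norm_shift playing the role of gain and shift here.

import numpy as np

def layer_norm(x, gain=1.0, shift=0.0, eps=1e-12):
    # Normalize each row to zero mean and unit variance, then rescale.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + shift

print(layer_norm(np.array([[1.0, 2.0, 3.0, 4.0]])))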


@ -115,7 +115,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The class uses optional peep-hole connections, and an optional projection The class uses optional peep-hole connections, and an optional projection
layer. layer.
Layer normalization implementation is based on: Layer normalization implementation is based on:
https://arxiv.org/abs/1607.06450. https://arxiv.org/abs/1607.06450.
@ -127,12 +126,21 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
""" """
def __init__(self, num_units, use_peepholes=False, def __init__(self,
initializer=None, num_proj=None, proj_clip=None, num_units,
num_unit_shards=1, num_proj_shards=1, use_peepholes=False,
forget_bias=1.0, state_is_tuple=True, initializer=None,
activation=math_ops.tanh, reuse=None, num_proj=None,
layer_norm=False, norm_gain=1.0, norm_shift=0.0): proj_clip=None,
num_unit_shards=1,
num_proj_shards=1,
forget_bias=1.0,
state_is_tuple=True,
activation=math_ops.tanh,
reuse=None,
layer_norm=False,
norm_gain=1.0,
norm_shift=0.0):
"""Initialize the parameters for an LSTM cell. """Initialize the parameters for an LSTM cell.
Args: Args:
@ -164,8 +172,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
`layer_norm` has been set to `False`, this argument will be ignored. `layer_norm` has been set to `False`, this argument will be ignored.
norm_shift: float, The layer normalization shift initial value. If norm_shift: float, The layer normalization shift initial value. If
`layer_norm` has been set to `False`, this argument will be ignored. `layer_norm` has been set to `False`, this argument will be ignored.
""" """
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple: if not state_is_tuple:
@ -2049,8 +2055,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
if self._skip_connection: if self._skip_connection:
self._total_output_channels += self._input_shape[-1] self._total_output_channels += self._input_shape[-1]
state_size = tensor_shape.TensorShape(self._input_shape[:-1] state_size = tensor_shape.TensorShape(
+ [self._output_channels]) self._input_shape[:-1] + [self._output_channels])
self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size) self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
self._output_size = tensor_shape.TensorShape(self._input_shape[:-1] self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
+ [self._total_output_channels]) + [self._total_output_channels])
@ -2110,11 +2116,8 @@ class Conv3DLSTMCell(ConvLSTMCell):
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details.""" """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs) super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
def _conv(args,
filter_size, def _conv(args, filter_size, num_features, bias, bias_start=0.0):
num_features,
bias,
bias_start=0.0):
"""convolution: """convolution:
Args: Args:
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
@ -2391,12 +2394,19 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
""" """
def __init__(self, num_units, def __init__(self,
use_peepholes=False, cell_clip=None, num_units,
initializer=None, num_proj=None, proj_clip=None, use_peepholes=False,
cell_clip=None,
initializer=None,
num_proj=None,
proj_clip=None,
forget_bias=1.0, forget_bias=1.0,
activation=None, layer_norm=False, activation=None,
norm_gain=1.0, norm_shift=0.0, reuse=None): layer_norm=False,
norm_gain=1.0,
norm_shift=0.0,
reuse=None):
"""Initialize the parameters for an LSTM cell. """Initialize the parameters for an LSTM cell.
Args: Args:
@ -2457,7 +2467,6 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
def output_size(self): def output_size(self):
return self._output_size return self._output_size
def _linear(self, def _linear(self,
args, args,
output_size, output_size,
@ -2521,9 +2530,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
if bias_initializer is None: if bias_initializer is None:
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
biases = vs.get_variable( biases = vs.get_variable(
"bias", [output_size], "bias", [output_size], dtype=dtype, initializer=bias_initializer)
dtype=dtype,
initializer=bias_initializer)
if not layer_norm: if not layer_norm:
res = nn_ops.bias_add(res, biases) res = nn_ops.bias_add(res, biases)
@ -2554,7 +2561,6 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
ValueError: If input size cannot be inferred from inputs via ValueError: If input size cannot be inferred from inputs via
static shape inference. static shape inference.
""" """
num_proj = self._num_units if self._num_proj is None else self._num_proj
sigmoid = math_ops.sigmoid sigmoid = math_ops.sigmoid
(c_prev, m_prev) = state (c_prev, m_prev) = state
@ -2567,8 +2573,12 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
# i = input_gate, j = new_input, f = forget_gate, o = output_gate # i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = self._linear([inputs, m_prev], 4 * self._num_units, bias=True, lstm_matrix = self._linear(
bias_initializer=None, layer_norm=self._layer_norm) [inputs, m_prev],
4 * self._num_units,
bias=True,
bias_initializer=None,
layer_norm=self._layer_norm)
i, j, f, o = array_ops.split( i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1) value=lstm_matrix, num_or_size_splits=4, axis=1)
@ -2580,7 +2590,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
# Diagonal connections # Diagonal connections
if self._use_peepholes: if self._use_peepholes:
with vs.variable_scope(unit_scope) as projection_scope: with vs.variable_scope(unit_scope):
w_f_diag = vs.get_variable( w_f_diag = vs.get_variable(
"w_f_diag", shape=[self._num_units], dtype=dtype) "w_f_diag", shape=[self._num_units], dtype=dtype)
w_i_diag = vs.get_variable( w_i_diag = vs.get_variable(
@ -2589,11 +2599,13 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
"w_o_diag", shape=[self._num_units], dtype=dtype) "w_o_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes: if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + c = (
sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
sigmoid(i + w_i_diag * c_prev) * self._activation(j)) sigmoid(i + w_i_diag * c_prev) * self._activation(j))
else: else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * c = (
self._activation(j)) sigmoid(f + self._forget_bias) * c_prev +
sigmoid(i) * self._activation(j))
if self._layer_norm: if self._layer_norm:
c = _norm(self._norm_gain, self._norm_shift, c, "state") c = _norm(self._norm_gain, self._norm_shift, c, "state")
@ -2608,7 +2620,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
m = sigmoid(o) * self._activation(c) m = sigmoid(o) * self._activation(c)
if self._num_proj is not None: if self._num_proj is not None:
with vs.variable_scope("projection") as proj_scope: with vs.variable_scope("projection"):
m = self._linear(m, self._num_proj, bias=False) m = self._linear(m, self._num_proj, bias=False)
if self._proj_clip is not None: if self._proj_clip is not None:
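
The reformatted cell update above is easier to see in isolation. The numpy sketch below mirrors the plain and peephole variants; tanh is assumed as the activation and the output-gate peephole term is an approximation of the cell's behavior, not a copy of the contrib code.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_state_update(i, j, f, o, c_prev, forget_bias=1.0,
                      w_f_diag=None, w_i_diag=None, w_o_diag=None):
    # Gate pre-activations i, j, f, o come from a single linear projection.
    if w_f_diag is not None:
        # Peephole variant: the gates also see the previous cell state.
        c = (sigmoid(f + forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * np.tanh(j))
        m = sigmoid(o + w_o_diag * c) * np.tanh(c)
    else:
        c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)
        m = sigmoid(o) * np.tanh(c)
    return c, m

c, m = lstm_state_update(*(np.ones((4, 2)) * 0.1 for _ in range(5)))
print(c.shape, m.shape)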


@ -192,7 +192,8 @@ class _BaseAttentionMechanism(AttentionMechanism):
raise TypeError("probability_fn must be callable, saw type: %s" % raise TypeError("probability_fn must be callable, saw type: %s" %
type(probability_fn).__name__) type(probability_fn).__name__)
if score_mask_value is None: if score_mask_value is None:
score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf) score_mask_value = dtypes.as_dtype(
self._memory_layer.dtype).as_numpy_dtype(-np.inf)
self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda
probability_fn( probability_fn(
_maybe_mask_score(score, memory_sequence_length, score_mask_value), _maybe_mask_score(score, memory_sequence_length, score_mask_value),
@ -1145,7 +1146,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
% (len(attention_layer_sizes), len(attention_mechanisms))) % (len(attention_layer_sizes), len(attention_mechanisms)))
self._attention_layers = tuple( self._attention_layers = tuple(
layers_core.Dense( layers_core.Dense(
attention_layer_size, name="attention_layer", use_bias=False, attention_layer_size,
name="attention_layer",
use_bias=False,
dtype=attention_mechanisms[i].dtype) dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes)) for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes) self._attention_layer_size = sum(attention_layer_sizes)
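
The score_mask_value plumbing above exists so that padded time steps are driven to -inf before the probability function runs. A small numpy sketch of that masking (the shapes and helper name are assumptions of the sketch):

import numpy as np

def mask_scores(scores, lengths, mask_value=-np.inf):
    # scores: [batch, max_time]; positions past each length get mask_value.
    positions = np.arange(scores.shape[1])[None, :]
    keep = positions < np.asarray(lengths)[:, None]
    return np.where(keep, scores, mask_value)

masked = mask_scores(np.array([[0.1, 0.2, 0.3]]), [2])
probs = np.exp(masked) / np.exp(masked).sum(axis=1, keepdims=True)
print(probs)  # the third position receives zero probability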


@ -16,6 +16,7 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS #ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma.h" #include "tensorflow/contrib/verbs/rdma.h"
#include <fcntl.h>
#include <cstdlib> #include <cstdlib>
#include <fcntl.h> #include <fcntl.h>
#include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/contrib/verbs/verbs_util.h"
@ -199,9 +200,7 @@ uint8_t set_port(ibv_context* context) {
// check if port id active // check if port id active
CHECK(port_attr.state == IBV_PORT_ACTIVE) CHECK(port_attr.state == IBV_PORT_ACTIVE)
<< "Selected RDMA_DEVICE_PORT is not active"; << "Selected RDMA_DEVICE_PORT is not active";
} } else { // set default port
// set default port
else {
for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
rc = ibv_query_port(context, port_index, &port_attr); rc = ibv_query_port(context, port_index, &port_attr);
CHECK(!rc) << "Failed to query the port" << port_index; CHECK(!rc) << "Failed to query the port" << port_index;
@ -269,7 +268,7 @@ bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
// Function to set GID index. // Function to set GID index.
// If the port link is IB, no GID index should be selected. // If the port link is IB, no GID index should be selected.
// If Ethernet but RDMA_GID_INDEX not set gid index that supports // If Ethernet but RDMA_GID_INDEX not set gid index that supports
// RoCE V2 will be chosen(fails if more then one IP is configured) // RoCE V2 will be chosen(fails if more than one IP is configured)
// Args: // Args:
// context - device context // context - device context
// port_num - port number // port_num - port number
@ -374,7 +373,8 @@ enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
break; break;
default: default:
CHECK(0) << "Error: MTU input value must be one of the following: 256, " CHECK(0) << "Error: MTU input value must be one of the following: 256, "
"512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n"; "512, 1024, 2048, 4096. MTU "
<< mtu << " is invalid\n";
break; break;
} }
CHECK(mtu < port_attr.active_mtu) CHECK(mtu < port_attr.active_mtu)
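
As a rough sketch of the validation these CHECK messages describe, assuming the requested value must be one of the listed sizes and must not exceed the port's active MTU (the exact comparison in the source is against the queried port attribute):

VALID_MTUS = (256, 512, 1024, 2048, 4096)

def choose_mtu(requested, active_mtu):
    # Reject values outside the supported set or beyond the port's active MTU.
    if requested not in VALID_MTUS:
        raise ValueError("MTU must be one of %s, got %d" % (VALID_MTUS, requested))
    if requested > active_mtu:
        raise ValueError("requested MTU %d exceeds active MTU %d" % (requested, active_mtu))
    return requested

print(choose_mtu(1024, active_mtu=4096))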


@ -921,7 +921,7 @@ Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
if (first_value == 0) { if (first_value == 0) {
*out = MakeDim(second); *out = MakeDim(second);
} else if (second_value == 0) { } else if (second_value == 0) {
*out = MakeDim(first); *out = first;
} else if (first_value == kUnknownDim || second_value == kUnknownDim) { } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim(); *out = UnknownDim();
} else { } else {
@ -946,7 +946,7 @@ Status InferenceContext::Subtract(DimensionHandle first,
const int64 second_value = Value(second); const int64 second_value = Value(second);
// Special cases. // Special cases.
if (second_value == 0) { if (second_value == 0) {
*out = MakeDim(first); *out = first;
} else if (first_value == kUnknownDim || second_value == kUnknownDim) { } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim(); *out = UnknownDim();
} else { } else {
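
The special-casing in Add and Subtract is easier to read outside the InferenceContext plumbing. A minimal Python rendering of the Add cases shown above, with -1 standing in for kUnknownDim:

UNKNOWN = -1  # stands in for InferenceContext::kUnknownDim

def add_dims(first, second):
    # Zero is an identity, unknown propagates, otherwise add the values.
    if first == 0:
        return second
    if second == 0:
        return first
    if first == UNKNOWN or second == UNKNOWN:
        return UNKNOWN
    return first + second

print(add_dims(3, 0), add_dims(UNKNOWN, 5), add_dims(2, 4))  # 3 -1 6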


@ -455,7 +455,6 @@ class Graph {
// the corresponding NodeDef to reflect the change. // the corresponding NodeDef to reflect the change.
// REQUIRES: The control edge must exist. // REQUIRES: The control edge must exist.
void RemoveControlEdge(const Edge* e); void RemoveControlEdge(const Edge* e);
// Updates the input to a node. The existing edge to `dst` is removed and an // Updates the input to a node. The existing edge to `dst` is removed and an
// edge from `new_src` to `dst` is created. The NodeDef associated with `dst` // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
// is also updated. // is also updated.


@ -118,11 +118,9 @@ class GraphTest : public ::testing::Test {
LOG(FATAL) << name; LOG(FATAL) << name;
} }
bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
const Node* dst) {
for (const Edge* e : dst->in_edges()) { for (const Edge* e : dst->in_edges()) {
if (e->IsControlEdge() && if (e->IsControlEdge() && e->src() == src &&
e->src() == src &&
e->src_output() == Graph::kControlSlot && e->src_output() == Graph::kControlSlot &&
e->dst_input() == Graph::kControlSlot) { e->dst_input() == Graph::kControlSlot) {
return true; return true;


@ -702,12 +702,16 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
Status GraphProperties::PropagateShapes( Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes, SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>& const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources) const { resources,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of // incorrect shape functions. The algorithm should converge in at most
// incorrect shape functions. The algorithm should converge in at most // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
// num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4. // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
// The same applies to resources. // The same applies to resources.
const int64 num_loops = new_shapes->size(); VLOG(1) << "Propagating (relax=" << relax << ") " << new_shapes->size()
<< " new shapes through " << num_loops << " loops and "
<< resources.size() << " resources" << std::endl;
const int64 max_loop_length = item_.graph.node_size(); const int64 max_loop_length = item_.graph.node_size();
const int64 max_rank = 4; const int64 max_rank = 4;
const int64 max_loop_iterations = const int64 max_loop_iterations =
@ -721,11 +725,14 @@ Status GraphProperties::PropagateShapes(
while (!new_shapes->empty() && while (!new_shapes->empty() &&
num_loop_iterations++ < max_loop_iterations) { num_loop_iterations++ < max_loop_iterations) {
const Node* n = new_shapes->pop(); const Node* n = new_shapes->pop();
for (const Node* fanout : n->out_nodes()) { for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
const Node* fanout = e->dst();
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
UpdateShapes(shape_refiner, relax, fanout, new_shapes)); UpdateShapes(shape_refiner, relax, fanout, new_shapes));
} }
} }
}
for (const auto& resource : resources) { for (const auto& resource : resources) {
// Resources need special handling: since the enqueue nodes are in the // Resources need special handling: since the enqueue nodes are in the
@ -818,6 +825,7 @@ Status GraphProperties::InferStatically() {
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources; std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
std::unordered_set<const Node*> enter_nodes; std::unordered_set<const Node*> enter_nodes;
std::unordered_set<const Node*> merge_nodes; std::unordered_set<const Node*> merge_nodes;
int num_loops = 0;
for (const Node* const node : graph.nodes()) { for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) { for (int i = 0; i < node->num_inputs(); ++i) {
if (node->input_type(i) == DataType::DT_RESOURCE) { if (node->input_type(i) == DataType::DT_RESOURCE) {
@ -830,6 +838,8 @@ Status GraphProperties::InferStatically() {
enter_nodes.insert(node); enter_nodes.insert(node);
} else if (node->IsMerge()) { } else if (node->IsMerge()) {
merge_nodes.insert(node); merge_nodes.insert(node);
} else if (node->IsNextIteration()) {
++num_loops;
} }
} }
@ -853,7 +863,7 @@ Status GraphProperties::InferStatically() {
} }
// Propagate shapes normally. // Propagate shapes normally.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
PropagateShapes(&refiner, relax, &new_shapes, resources)); PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
} }
// Track shapes globally across the graph. // Track shapes globally across the graph.
@ -906,6 +916,9 @@ Status GraphProperties::InferStatically() {
&input_properties[i]); &input_properties[i]);
} }
for (const auto& edge : node->in_edges()) { for (const auto& edge : node->in_edges()) {
if (edge->IsControlEdge()) {
continue;
}
if (!edge->src()->IsConstant()) { if (!edge->src()->IsConstant()) {
continue; continue;
} }
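
The new num_loops parameter feeds the iteration cap discussed in the comment above. A toy sketch of the traversal follows; it models only the worklist and the control-edge skip, the real UpdateShapes re-enqueues a fanout only when its shapes actually change, and the exact cap formula is assumed here.

from collections import deque

def propagate(seed_nodes, data_fanouts, num_loops, graph_size, max_rank=4):
    # Cap the work so a buggy shape function cannot spin forever.
    max_iterations = max(1, num_loops * num_loops * max_rank * graph_size)
    queue, iterations = deque(seed_nodes), 0
    while queue and iterations < max_iterations:
        iterations += 1
        node = queue.popleft()
        for fanout in data_fanouts.get(node, ()):  # control edges are skipped
            queue.append(fanout)
    return iterations

fanouts = {"a": ["b"], "b": ["c"], "c": []}
print(propagate(["a"], fanouts, num_loops=1, graph_size=3))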


@ -108,7 +108,8 @@ class GraphProperties {
Status PropagateShapes( Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes, SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>& const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources) const; resources,
int num_loops) const;
}; };
} // end namespace grappler } // end namespace grappler


@ -24,64 +24,40 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
bool IsAdd(const NodeDef& node) { bool IsAdd(const NodeDef& node) { return node.op() == "Add"; }
const auto op = node.op();
return op == "Add";
}
bool IsAddN(const NodeDef& node) { bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
const auto op = node.op();
return op == "AddN";
}
bool IsAvgPoolGrad(const NodeDef& node) { bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
const auto op = node.op();
return op == "AvgPoolGrad";
}
bool IsBiasAddGrad(const NodeDef& node) { bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
const auto op = node.op();
return op == "BiasAddGrad";
}
bool IsConcatOffset(const NodeDef& node) { bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
const auto op = node.op();
return op == "ConcatOffset";
}
bool IsConstant(const NodeDef& node) { bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
const auto op = node.op();
return op == "Const";
}
bool IsConv2D(const NodeDef& node) { bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
const auto op = node.op();
return op == "Conv2D"; bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
}
bool IsConv2DBackpropFilter(const NodeDef& node) { bool IsConv2DBackpropFilter(const NodeDef& node) {
const auto op = node.op(); return node.op() == "Conv2DBackpropFilter";
return op == "Conv2DBackpropFilter";
} }
bool IsConv2DBackpropInput(const NodeDef& node) { bool IsConv2DBackpropInput(const NodeDef& node) {
const auto op = node.op(); return node.op() == "Conv2DBackpropInput";
return op == "Conv2DBackpropInput";
} }
bool IsDepthwiseConv2dNative(const NodeDef& node) { bool IsDepthwiseConv2dNative(const NodeDef& node) {
const auto op = node.op(); return node.op() == "DepthwiseConv2dNative";
return op == "DepthwiseConv2dNative";
} }
bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) { bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
const auto op = node.op(); return node.op() == "DepthwiseConv2dNativeBackpropFilter";
return op == "DepthwiseConv2dNativeBackpropFilter";
} }
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) { bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
const auto op = node.op(); return node.op() == "DepthwiseConv2dNativeBackpropInput";
return op == "DepthwiseConv2dNativeBackpropInput";
} }
bool IsDequeueOp(const NodeDef& node) { bool IsDequeueOp(const NodeDef& node) {
@ -101,14 +77,10 @@ bool IsExit(const NodeDef& node) {
return op == "Exit" || op == "RefExit"; return op == "Exit" || op == "RefExit";
} }
bool IsFloorMod(const NodeDef& node) { bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
const auto& op = node.op();
return op == "FloorMod";
}
bool IsFusedBatchNormGradV1(const NodeDef& node) { bool IsFusedBatchNormGradV1(const NodeDef& node) {
const auto& op = node.op(); return node.op() == "FusedBatchNormGrad";
return op == "FusedBatchNormGrad";
} }
bool IsIdentity(const NodeDef& node) { bool IsIdentity(const NodeDef& node) {
@ -121,25 +93,16 @@ bool IsMerge(const NodeDef& node) {
return op == "Merge" || op == "RefMerge"; return op == "Merge" || op == "RefMerge";
} }
bool IsMul(const NodeDef& node) { bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
const auto op = node.op();
return op == "Mul";
}
bool IsNoOp(const NodeDef& node) { bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
const auto op = node.op();
return op == "NoOp";
}
bool IsNextIteration(const NodeDef& node) { bool IsNextIteration(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "NextIteration" || op == "RefNextIteration"; return op == "NextIteration" || op == "RefNextIteration";
} }
bool IsPad(const NodeDef& node) { bool IsPad(const NodeDef& node) { return node.op() == "Pad"; }
const auto op = node.op();
return op == "Pad";
}
bool IsPlaceholder(const NodeDef& node) { bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op(); const auto op = node.op();
@ -147,20 +110,11 @@ bool IsPlaceholder(const NodeDef& node) {
op == "PlaceholderWithDefault"; op == "PlaceholderWithDefault";
} }
bool IsRealDiv(const NodeDef& node) { bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
const auto op = node.op();
return op == "RealDiv";
}
bool IsReluGrad(const NodeDef& node) { bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
const auto op = node.op();
return op == "ReluGrad";
}
bool IsRecv(const NodeDef& node) { bool IsRecv(const NodeDef& node) { return node.op() == "_Recv"; }
const auto op = node.op();
return op == "_Recv";
}
bool IsReduction(const NodeDef& node) { bool IsReduction(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
@ -175,53 +129,34 @@ bool IsRestore(const NodeDef& node) {
node.op() == "RestoreSlice"); node.op() == "RestoreSlice");
} }
bool IsSend(const NodeDef& node) { bool IsSend(const NodeDef& node) { return node.op() == "_Send"; }
const auto op = node.op();
return op == "_Send";
}
bool IsSlice(const NodeDef& node) { bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
const auto op = node.op();
return op == "Slice";
}
bool IsSquaredDifference(const NodeDef& node) { bool IsSquaredDifference(const NodeDef& node) {
const auto op = node.op(); return node.op() == "SquaredDifference";
return op == "SquaredDifference";
} }
bool IsSqueeze(const NodeDef& node) { bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
const auto op = node.op();
return op == "Squeeze";
}
bool IsStopGradient(const NodeDef& node) { bool IsStopGradient(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "StopGradient" || op == "PreventGradient"; return op == "StopGradient" || op == "PreventGradient";
} }
bool IsSub(const NodeDef& node) { bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
const auto op = node.op();
return op == "Sub";
}
bool IsSum(const NodeDef& node) { bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
const auto op = node.op();
return op == "Sum";
}
bool IsSwitch(const NodeDef& node) { bool IsSwitch(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "Switch" || op == "RefSwitch"; return op == "Switch" || op == "RefSwitch";
} }
bool IsTranspose(const NodeDef& node) { bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
const auto op = node.op();
return op == "Transpose";
}
bool IsVariable(const NodeDef& node) { bool IsVariable(const NodeDef& node) {
const auto op = node.op(); const auto& op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
op == "VarHandleOp" || op == "ReadVariableOp"; op == "VarHandleOp" || op == "ReadVariableOp";
} }


@ -25,6 +25,7 @@ namespace grappler {
bool IsAdd(const NodeDef& node); bool IsAdd(const NodeDef& node);
bool IsAddN(const NodeDef& node); bool IsAddN(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node); bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node); bool IsBiasAddGrad(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node); bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node); bool IsConstant(const NodeDef& node);


@ -448,6 +448,10 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
if (node.device().find("SPU") != string::npos) { if (node.device().find("SPU") != string::npos) {
return false; return false;
} }
// Workaround for Assert mistakenly being labeled as stateful.
if (IsAssert(node)) {
return true;
}
return IsFreeOfSideEffect(node); return IsFreeOfSideEffect(node);
} }
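
A small Python sketch of the dedup idea this carve-out serves: nodes with an identical signature collapse onto one representative, and Assert is allowed in despite being marked stateful because duplicate asserts check the same predicate. The signature below ignores attrs and device, which the real optimizer does consider.

def dedup(nodes):
    # Map each duplicate node name onto the first node with the same signature.
    seen, replacements = {}, {}
    for name, node in nodes.items():
        key = (node["op"], tuple(node["inputs"]))
        if key in seen:
            replacements[name] = seen[key]
        else:
            seen[key] = name
    return replacements

nodes = {
    "assert1": {"op": "Assert", "inputs": ("p", "c")},
    "assert2": {"op": "Assert", "inputs": ("p", "c")},
}
print(dedup(nodes))  # {'assert2': 'assert1'}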


@ -81,6 +81,38 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
EXPECT_EQ("c1", new_mul.input(1)); EXPECT_EQ("c1", new_mul.input(1));
} }
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
Output mul = ops::Multiply(s.WithOpName("mul").WithControlDependencies(
{assert1.operation, assert2.operation}),
check1, check2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ArithmeticOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
// Run the optimizer twice to make sure the rewrite is idempotent.
item.graph.Swap(&output);
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(5, output.node_size());
const NodeDef& new_mul = output.node(3);
EXPECT_EQ(4, new_mul.input_size());
EXPECT_EQ("check1", new_mul.input(0));
EXPECT_EQ("check1", new_mul.input(1));
EXPECT_EQ("^assert1", new_mul.input(2));
EXPECT_EQ("^assert1", new_mul.input(3));
}
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2}); Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});

View File

@ -1720,6 +1720,7 @@ tf_cuda_cc_tests(
":data_flow", ":data_flow",
":ops_testutil", ":ops_testutil",
":ops_util", ":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",


@ -97,8 +97,9 @@ class BincountOp : public OpKernel {
const Tensor& weights_t = ctx->input(2); const Tensor& weights_t = ctx->input(2);
int32 size = size_tensor.scalar<int32>()(); int32 size = size_tensor.scalar<int32>()();
OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( OP_REQUIRES(
"size (", size, ") must be non-negative")); ctx, size >= 0,
errors::InvalidArgument("size (", size, ") must be non-negative"));
const auto arr = arr_t.flat<int32>(); const auto arr = arr_t.flat<int32>();
const auto weights = weights_t.flat<T>(); const auto weights = weights_t.flat<T>();
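
What the kernel computes, sketched in plain Python: per-value counts (or weighted sums) into a buffer of length size, after the non-negativity check above. Dropping values at or above size is an assumption of this sketch.

import numpy as np

def bincount(arr, size, weights=None):
    if size < 0:
        raise ValueError("size (%d) must be non-negative" % size)
    w = np.ones(len(arr)) if weights is None else np.asarray(weights)
    out = np.zeros(size, dtype=w.dtype)
    for value, weight in zip(arr, w):
        if 0 <= value < size:
            out[value] += weight
    return out

print(bincount([0, 1, 1, 3, 7], size=4))  # [1. 2. 0. 1.]; 7 falls outside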


@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_BINCOUNT_OP_H_ #ifndef TENSORFLOW_BINCOUNT_OP_H_
#define TENSORFLOW_BINCOUNT_OP_H_ #define TENSORFLOW_BINCOUNT_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow { namespace tensorflow {


@ -17,12 +17,12 @@ limitations under the License.
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "tensorflow/core/kernels/bincount_op.h"
#include "external/cub_archive/cub/device/device_histogram.cuh" #include "external/cub_archive/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bincount_op.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/util/cuda_kernel_helper.h"
@ -93,8 +93,8 @@ struct BincountFunctor<GPUDevice, T> {
/* num_samples */ num_samples, /* num_samples */ num_samples,
/* stream */ stream); /* stream */ stream);
if (err != cudaSuccess) { if (err != cudaSuccess) {
return errors::Internal("Could not launch HistogramEven: ", return errors::Internal(
cudaGetErrorString(err), "."); "Could not launch HistogramEven: ", cudaGetErrorString(err), ".");
} }
return Status::OK(); return Status::OK();
} }


@ -30,8 +30,8 @@ static Graph* Bincount(int arr_size, int nbins) {
Tensor arr(DT_INT32, TensorShape({arr_size})); Tensor arr(DT_INT32, TensorShape({arr_size}));
arr.flat<int32>() = arr.flat<int32>().setRandom().abs(); arr.flat<int32>() = arr.flat<int32>().setRandom().abs();
Tensor size(DT_INT32, TensorShape({(int32)1})); Tensor size(DT_INT32, TensorShape({static_cast<int32>(1)}));
size.flat<int32>()(0) = (int32)nbins; size.flat<int32>()(0) = static_cast<int32>(nbins);
Tensor weights(DT_INT32, TensorShape({0})); Tensor weights(DT_INT32, TensorShape({0}));


@ -77,8 +77,8 @@ struct BucketizeFunctor<GPUDevice, T> {
TF_RETURN_IF_ERROR(boundaries_array.Finalize()); TF_RETURN_IF_ERROR(boundaries_array.Finalize());
CudaLaunchConfig config = GetCudaLaunchConfig(input.size(), d); CudaLaunchConfig config = GetCudaLaunchConfig(input.size(), d);
BucketizeCustomKernel< BucketizeCustomKernel<T>
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
input.size(), input.data(), boundaries_vector.size(), input.size(), input.data(), boundaries_vector.size(),
boundaries_array.data(), output.data()); boundaries_array.data(), output.data());
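
The kernel launch above implements bucketization; in numpy terms it is a searchsorted over the boundary list. Placing exact boundary hits in the upper bucket is assumed in this sketch.

import numpy as np

def bucketize(values, boundaries):
    # Index of the first boundary strictly greater than the value.
    return np.searchsorted(boundaries, values, side="right")

print(bucketize(np.array([-5.0, 0.0, 3.0, 10.0]), [0.0, 1.0, 5.0]))  # [0 1 2 3]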


@ -1101,8 +1101,6 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
bool cudnn_use_autotune_; bool cudnn_use_autotune_;
}; };
#define REGISTER_GPU_KERNEL(T) \ #define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \


@ -231,7 +231,8 @@ static void CopyOutputBackpropRegion(const DepthwiseArgs& args,
} }
// Pad to vector-register width (if needed). // Pad to vector-register width (if needed).
for (int64 d = 0; d < pad_size; ++d) { for (int64 d = 0; d < pad_size; ++d) {
buffer[buf_base + vectorized_size + scalar_size + d] = static_cast<T>(0); buffer[buf_base + vectorized_size + scalar_size + d] =
static_cast<T>(0);
} }
} }
} }
@ -510,7 +511,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
#if GOOGLE_CUDA #if GOOGLE_CUDA
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, Eigen::half>; extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>; extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>; extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;
@ -885,7 +887,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
#if GOOGLE_CUDA #if GOOGLE_CUDA
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, Eigen::half>; extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>; extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>; extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;


@ -158,7 +158,8 @@ struct DepthwiseFilterPadOp {
} }
// Pad the remainder of output to vector-register boundary. // Pad the remainder of output to vector-register boundary.
for (int64 j = 0; j < pad_size; ++j) { for (int64 j = 0; j < pad_size; ++j) {
padded_filter[output_base + vectorized_size + scalar_size + j] = static_cast<T>(0); padded_filter[output_base + vectorized_size + scalar_size + j] =
static_cast<T>(0);
} }
} }
} }


@ -73,18 +73,22 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
std::move(other_arguments), std::move(other_arguments),
&captured_func)); &captured_func));
*output = new Dataset(input, std::move(captured_func), cycle_length, *output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
block_length, output_types_, output_shapes_); block_length, output_types_, output_shapes_);
} }
private: private:
class Dataset : public DatasetBase { class Dataset : public GraphDatasetBase {
public: public:
Dataset(const DatasetBase* input, Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, const DataTypeVector& output_types, int64 block_length, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes) const std::vector<PartialTensorShape>& output_shapes)
: input_(input), : GraphDatasetBase(ctx),
input_(input),
func_(func),
captured_func_(std::move(captured_func)), captured_func_(std::move(captured_func)),
cycle_length_(cycle_length), cycle_length_(cycle_length),
block_length_(block_length), block_length_(block_length),
@ -110,13 +114,47 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "InterleaveDatasetOp::Dataset"; } string DebugString() override { return "InterleaveDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
std::vector<NodeBuilder::NodeOut> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
b->BuildAttrValue(func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{{0, input_node}, {2, cycle_length_node}, {3, block_length_node}},
{{1, other_arguments}},
{{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
return Status::OK();
}
private: private:
class Iterator : public DatasetIterator<Dataset> { class Iterator : public DatasetIterator<Dataset> {
public: public:
explicit Iterator(const Params& params) explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), : DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)), input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
current_elements_(params.dataset->cycle_length_) {} current_elements_(params.dataset->cycle_length_),
args_list_(params.dataset->cycle_length_) {}
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) { void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
block_index_ = 0; block_index_ = 0;
@ -150,18 +188,19 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
// We have reached the end of the current element, so move // We have reached the end of the current element, so move
// on to the next element in the cycle. // on to the next element in the cycle.
current_elements_[cycle_index_].reset(); current_elements_[cycle_index_].reset();
args_list_[cycle_index_].clear();
--num_open_; --num_open_;
AdvanceToNextInCycle(); AdvanceToNextInCycle();
} else if (!end_of_input_) { } else if (!end_of_input_) {
// Get the next element from the input dataset, and create // Get the next element from the input dataset, and create
// an iterator from it. // an iterator from it.
std::vector<Tensor> args; TF_RETURN_IF_ERROR(input_impl_->GetNext(
TF_RETURN_IF_ERROR( ctx, &args_list_[cycle_index_], &end_of_input_));
input_impl_->GetNext(ctx, &args, &end_of_input_));
if (!end_of_input_) { if (!end_of_input_) {
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
ctx, args, cycle_index_, dataset()->captured_func_.get(), ctx, args_list_[cycle_index_], cycle_index_,
prefix(), &current_elements_[cycle_index_])); dataset()->captured_func_.get(), prefix(),
&current_elements_[cycle_index_]));
++num_open_; ++num_open_;
} }
} else { } else {
@ -173,11 +212,100 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK(); return Status::OK();
} }
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("cycle_index"), cycle_index_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("block_index"), block_index_));
if (end_of_input_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input"), ""));
}
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("num_open"), num_open_));
TF_RETURN_IF_ERROR(SaveCurrentElements(writer));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
int64 cycle_index;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("cycle_index"), &cycle_index));
cycle_index_ = size_t(cycle_index);
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("block_index"), &block_index_));
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
int64 num_open;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("num_open"), &num_open));
num_open_ = size_t(num_open);
TF_RETURN_IF_ERROR(RestoreCurrentElements(ctx, reader));
return Status::OK();
}
private: private:
Status SaveCurrentElements(IteratorStateWriter* writer)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveParent(writer, current_elements_[idx]));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("args_size[", idx, "]")),
args_list_[idx].size()));
for (int i = 0; i < args_list_[idx].size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
args_list_[idx][i]));
}
}
}
return Status::OK();
}
Status RestoreCurrentElements(OpKernelContext* ctx,
IteratorStateReader* reader)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
int64 args_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("args_size[", idx, "]")),
&args_size));
args_list_[idx].resize(args_size);
for (int i = 0; i < args_size; i++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
&iter_ctx, args_list_[idx], idx,
dataset()->captured_func_.get(), prefix(),
&current_elements_[idx]));
TF_RETURN_IF_ERROR(
RestoreParent(ctx, reader, current_elements_[idx]));
} else {
current_elements_[idx].reset();
}
}
return Status::OK();
}
mutex mu_; mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> current_elements_ std::vector<std::unique_ptr<IteratorBase>> current_elements_
GUARDED_BY(mu_); GUARDED_BY(mu_);
std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
size_t cycle_index_ GUARDED_BY(mu_) = 0; size_t cycle_index_ GUARDED_BY(mu_) = 0;
int64 block_index_ GUARDED_BY(mu_) = 0; int64 block_index_ GUARDED_BY(mu_) = 0;
bool end_of_input_ GUARDED_BY(mu_) = false; bool end_of_input_ GUARDED_BY(mu_) = false;
@ -185,6 +313,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
}; };
const DatasetBase* const input_; const DatasetBase* const input_;
const NameAttrList func_;
const std::unique_ptr<CapturedFunction> captured_func_; const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_; const int64 cycle_length_;
const int64 block_length_; const int64 block_length_;
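
Most of this change is about making the iterator's state (cycle_index, block_index, the open elements and their args_list_) checkpointable. The ordering that state describes can be sketched with plain Python iterators; this is a toy model of cycle_length/block_length behavior, not the op's implementation.

def interleave(inputs, fn, cycle_length, block_length):
    # Keep up to cycle_length inner iterators open; take block_length items
    # from each slot in turn, refilling an exhausted slot from the next input.
    pending = list(inputs)
    slots = [iter(fn(pending.pop(0)))
             for _ in range(min(cycle_length, len(pending)))]
    out, idx = [], 0
    while slots:
        slot = idx % len(slots)
        taken = 0
        while taken < block_length:
            try:
                out.append(next(slots[slot]))
                taken += 1
            except StopIteration:
                if pending:
                    slots[slot] = iter(fn(pending.pop(0)))
                else:
                    del slots[slot]
                break
        idx += 1
    return out

print(interleave([1, 2, 3], lambda x: [x] * 4, cycle_length=2, block_length=2))
# [1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3]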


@ -258,7 +258,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureOutputAllocated(batch_result, result->return_values); EnsureOutputAllocated(batch_result, result->return_values);
const size_t num_components = result->return_values.size(); const size_t num_components = result->return_values.size();
for (size_t i = 0; i < num_components; ++i) { for (size_t i = 0; i < num_components; ++i) {
Tensor tensor = result->return_values[i]; const Tensor& tensor = result->return_values[i];
Tensor* batch = &(batch_result->output)[i]; Tensor* batch = &(batch_result->output)[i];
if (tensor.NumElements() != if (tensor.NumElements() !=
(batch->NumElements() / batch->dim_size(0))) { (batch->NumElements() / batch->dim_size(0))) {
@ -271,6 +271,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
", [batch]: ", batch_shape.DebugString())); ", [batch]: ", batch_shape.DebugString()));
break; break;
} }
// TODO(mrry): Add a version of DoParallelConcat that allows
// us to move `tensor` where possible, to speed up string
// tensor batching.
Status copy_status = ::tensorflow::functor::DoParallelConcat( Status copy_status = ::tensorflow::functor::DoParallelConcat(
*dataset()->device_, tensor, offset, batch); *dataset()->device_, tensor, offset, batch);
if (!copy_status.ok()) { if (!copy_status.ok()) {
@ -279,6 +282,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
} }
} }
} }
// NOTE(mrry): We clear the return values here to release any
// memory associated with them and to parallelize the destruction
// of the tensors (which can be surprisingly expensive for
// map functions with large numbers of return values).
result->return_values.clear();
batch_result->counter->DecrementCount(); batch_result->counter->DecrementCount();
}); });
} }
@ -297,7 +305,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
for (size_t i = 0; i < dataset()->batch_size_; ++i) { for (size_t i = 0; i < dataset()->batch_size_; ++i) {
size_t index = ComputeInvocationIndex(batch_index, i); size_t index = ComputeInvocationIndex(batch_index, i);
InvocationResult* result = &invocation_results_[index]; InvocationResult* result = &invocation_results_[index];
*result = InvocationResult(); // Reset the state of `result`.
// NOTE(mrry): `result->return_values` were cleared when the previous
// invocation completed.
result->status = Status::OK();
} }
// Start individual invocations. // Start individual invocations.
for (size_t i = 0; i < dataset()->batch_size_; ++i) { for (size_t i = 0; i < dataset()->batch_size_; ++i) {

View File

@ -359,7 +359,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
use_dnn_ = CanUseCudnn(); use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_); TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -888,7 +889,8 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
errors::Unimplemented( errors::Unimplemented(
"Pooling is not yet supported on the batch dimension.")); "Pooling is not yet supported on the batch dimension."));
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_); TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -1052,7 +1054,8 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
"Pooling is not yet supported on the batch dimension.")); "Pooling is not yet supported on the batch dimension."));
use_dnn_ = CanUseCudnn(); use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_); TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -1137,7 +1140,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
} }
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
use_dnn_ = CanUseCudnn(); use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_); TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {


@ -405,15 +405,15 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
if (propagate_nans) { if (propagate_nans) {
MaxPoolForwardNHWC<true> MaxPoolForwardNHWC<true>
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>> kThreadsPerBlock, 0, d.stream()>>>(
(output_size, bottom_data, height, width, channels, pooled_height, output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask); top_data, mask);
} else { } else {
MaxPoolForwardNHWC<false> MaxPoolForwardNHWC<false>
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>> kThreadsPerBlock, 0, d.stream()>>>(
(output_size, bottom_data, height, width, channels, pooled_height, output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask); top_data, mask);
} }


@ -101,8 +101,8 @@ class MklToTfOp : public OpKernel {
// Allocate output tensor. // Allocate output tensor.
TensorShape output_shape = input_shape.GetTfShape(); TensorShape output_shape = input_shape.GetTfShape();
Tensor* output_tensor = NULL; Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(input_number, OP_REQUIRES_OK(context, context->allocate_output(
output_shape, &output_tensor)); input_number, output_shape, &output_tensor));
CHECK_NOTNULL(output_tensor); CHECK_NOTNULL(output_tensor);
// Do we need to reorder Mkl layout into TensorFlow layout? // Do we need to reorder Mkl layout into TensorFlow layout?
@ -118,10 +118,10 @@ class MklToTfOp : public OpKernel {
} }
} catch (mkldnn::error& e) { } catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) + string error_msg = "Status: " + std::to_string(e.status) +
", message: " + std::string(e.message) + ", message: " + std::string(e.message) + ", in file " +
", in file " + std::string(__FILE__) + ":" + std::string(__FILE__) + ":" + std::to_string(__LINE__);
std::to_string(__LINE__); OP_REQUIRES_OK(
OP_REQUIRES_OK(context, context,
errors::Aborted("Operation received an exception:", error_msg)); errors::Aborted("Operation received an exception:", error_msg));
} }
} }
@ -160,8 +160,8 @@ class MklToTfOp : public OpKernel {
// Allocate output tensor. // Allocate output tensor.
Tensor* output_tensor = NULL; Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(input_number, OP_REQUIRES_OK(context, context->allocate_output(input_number, output_shape,
output_shape, &output_tensor)); &output_tensor));
dnnLayout_t output_layout = dnnLayout_t output_layout =
static_cast<dnnLayout_t>(input_shape.GetTfLayout()); static_cast<dnnLayout_t>(input_shape.GetTfLayout());


@ -98,6 +98,19 @@ gtl::InlinedVector<T, 8> ComputeStride(const TensorShape& shape) {
return strides; return strides;
} }
// Helper to compute 'strides' given an Eigen TensorDimensions
template <typename T, typename EigenDimensions>
gtl::InlinedVector<T, 8> ComputeEigenStrides(const EigenDimensions& shape) {
const int ndims = shape.rank();
gtl::InlinedVector<T, 8> strides(ndims);
T stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[i] = stride;
stride *= static_cast<T>(shape[i]);
}
return strides;
}
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_ #endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
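
The new ComputeEigenStrides follows the usual row-major stride rule; the same computation in Python, for reference:

def compute_strides(shape):
    # Innermost dimension has stride 1; each outer stride is the product of
    # all inner dimension sizes, matching the helper added above.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

print(compute_strides([2, 3, 4]))  # [12, 4, 1]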


@ -181,16 +181,18 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padding_values.push_back(tensor::DeepCopy(padding_value_t)); padding_values.push_back(tensor::DeepCopy(padding_value_t));
} }
*output = new Dataset(batch_size, std::move(padded_shapes), *output = new Dataset(ctx, batch_size, std::move(padded_shapes),
std::move(padding_values), input); std::move(padding_values), input);
} }
private: private:
class Dataset : public DatasetBase { class Dataset : public GraphDatasetBase {
public: public:
Dataset(int64 batch_size, std::vector<PartialTensorShape> padded_shapes, Dataset(OpKernelContext* ctx, int64 batch_size,
std::vector<PartialTensorShape> padded_shapes,
std::vector<Tensor> padding_values, const DatasetBase* input) std::vector<Tensor> padding_values, const DatasetBase* input)
: batch_size_(batch_size), : GraphDatasetBase(ctx),
batch_size_(batch_size),
padded_shapes_(std::move(padded_shapes)), padded_shapes_(std::move(padded_shapes)),
padding_values_(std::move(padding_values)), padding_values_(std::move(padding_values)),
input_(input) { input_(input) {
@ -232,6 +234,47 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
")::Dataset"); ")::Dataset");
} }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
std::vector<NodeBuilder::NodeOut> padded_shapes;
padded_shapes.reserve(padded_shapes_.size());
for (int i = 0; i < padded_shapes_.size(); i++) {
Node* node;
Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
for (int j = 0; j < padded_shapes_[i].dims(); j++) {
t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
}
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
padded_shapes.emplace_back(node);
}
std::vector<NodeBuilder::NodeOut> padding_values;
padding_values.reserve(padding_values_.size());
for (const Tensor& t : padding_values_) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
padding_values.emplace_back(node);
}
AttrValue output_types;
b->BuildAttrValue(output_dtypes(), &output_types);
AttrValue N;
b->BuildAttrValue<int64>(padded_shapes_.size(), &N);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {{0, input_graph_node}, {1, batch_size}},
{{2, padded_shapes}, {3, padding_values}},
{{"Toutput_types", output_types}, {"N", N}}, output));
return Status::OK();
}
private: private:
// Copies element into the index^th slice of parent (in the 0th dimension). // Copies element into the index^th slice of parent (in the 0th dimension).
// //
@ -248,10 +291,14 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
// Each row of `batch_elements` is a tuple of tensors from the // Each row of `batch_elements` is a tuple of tensors from the
// input iterator. // input iterator.
std::vector<std::vector<Tensor>> batch_elements; std::vector<std::vector<Tensor>> batch_elements;
batch_elements.reserve(dataset()->batch_size_);
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
} else {
*end_of_sequence = false; *end_of_sequence = false;
batch_elements.reserve(dataset()->batch_size_);
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
++i) { ++i) {
std::vector<Tensor> batch_element_tuple; std::vector<Tensor> batch_element_tuple;
@ -261,6 +308,10 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
batch_elements.push_back(std::move(batch_element_tuple)); batch_elements.push_back(std::move(batch_element_tuple));
} }
} }
if (*end_of_sequence) {
input_impl_.reset();
}
}
} }
if (batch_elements.empty()) { if (batch_elements.empty()) {
@ -347,6 +398,28 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK(); return Status::OK();
} }
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (input_impl_)
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
else
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("exhausted"), ""));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (reader->Contains(full_name("exhausted"))) {
input_impl_.reset();
} else {
input_impl_ = dataset()->input_->MakeIterator(prefix());
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
}
private: private:
mutex mu_; mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
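
For reference, the padding behavior this dataset's iterator implements (and now serializes) reduces to padding every element out to the largest shape in the batch; a one-dimensional numpy sketch, with the padding value defaulting to zero as an assumption:

import numpy as np

def padded_batch(elements, padding_value=0):
    # Pad every element to the longest one in the batch, then stack.
    max_len = max(len(e) for e in elements)
    out = np.full((len(elements), max_len), padding_value,
                  dtype=np.asarray(elements[0]).dtype)
    for i, e in enumerate(elements):
        out[i, :len(e)] = e
    return out

print(padded_batch([[1, 2], [3], [4, 5, 6]]))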

View File

@ -352,13 +352,15 @@ class DeserializeSparseOp : public OpKernel {
i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
"] is: ", expanded_tensor_shape.dims() - 1));
for (int j = 1; j < shape.dims(); ++j) {
OP_REQUIRES(
context, shape.dim_size(j) == expanded_tensor_shape.dim_size(j),
errors::InvalidArgument(
"Inconsistent shape across SparseTensors: dimension ", j - 1,
" prior to SparseTensor[", i, "] was: ", shape.dim_size(j),
" but rank of SparseTensor[", i,
"] is: ", expanded_tensor_shape.dim_size(j)));
// NOTE(mrry): For compatibility with the implementations of
// DeserializeManySparse, and many ops that generate
// SparseTensors to batch that do not have a fixed
// dense_shape (e.g. `tf.parse_single_example()`), we
// compute the maximum in each dimension to find the
// smallest dense_shape that bounds all of the input
// SparseTensors.
shape.set_dim(j, std::max(shape.dim_size(j),
expanded_tensor_shape.dim_size(j)));
}
}
}
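The change above replaces the strict per-dimension equality check with a per-dimension maximum, mirroring what the NOTE says DeserializeManySparse already does; that is what makes it possible to batch SparseTensors whose dense_shapes differ. A small TF 1.x sketch of that behavior, with made-up shapes, assuming the maximum-bounding semantics described in the NOTE above:

import tensorflow as tf

# Two SparseTensors whose dense_shapes disagree: [2, 3] vs. [3, 1].
a = tf.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[2, 3])
b = tf.SparseTensor(indices=[[2, 0]], values=[2], dense_shape=[3, 1])

serialized = tf.stack([tf.serialize_sparse(a), tf.serialize_sparse(b)])
batched = tf.deserialize_many_sparse(serialized, dtype=tf.int32)

with tf.Session() as sess:
  # Expected dense_shape: [2, 3, 3], i.e. the batch size followed by the
  # per-dimension maximum of the input shapes.
  print(sess.run(batched.dense_shape))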

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/cloud/file_block_cache.h"
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
@ -696,6 +697,18 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);
if (out->size() < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
if (stat_cache_->Lookup(filename, &stat)) {
if (offset + out->size() < stat.length) {
return errors::Internal(strings::Printf(
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));
}
}
}
return Status::OK();
}
@ -816,7 +829,8 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
return errors::Internal("'stat' cannot be nullptr.");
}
if (object.empty()) {
return errors::InvalidArgument(strings::Printf(
"'object' must be a non-empty string. (File: %s)", fname.c_str()));
}
StatCache::ComputeFunc compute_func =
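The block added to LoadBufferFromGCS checks the stat cache when a block read comes back short and, if the cached length says more data should exist at that offset, fails with an Internal error about inconsistent file contents. The path is exercised by ordinary reads of gs:// files; a minimal sketch, with a hypothetical bucket and object name, assuming GCS credentials are configured:

import tensorflow as tf

# Hypothetical path; reads are served through the block cache checked above.
with tf.gfile.GFile("gs://my-bucket/training/data.txt", "rb") as f:
  contents = f.read()
print(len(contents))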

View File

@ -131,8 +131,8 @@ error::Code ErrnoToCode(int err_number) {
case ENETUNREACH:  // Network unreachable
case ENOLCK:       // No locks available
case ENOLINK:      // Link has been severed
#if !(defined(__APPLE__) || defined(__FreeBSD__) || defined(_WIN32) || \
defined(__HAIKU__))
case ENONET:  // Machine is not on the network
#endif
code = error::UNAVAILABLE;

View File

@ -37,8 +37,8 @@ limitations under the License.
#ifdef TF_USE_SNAPPY
#include "snappy.h"
#endif
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) || \
defined(__HAIKU__)
#include <thread>
#endif
@ -62,8 +62,8 @@ int NumSchedulableCPUs() {
}
perror("sched_getaffinity");
#endif
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) || \
defined(__HAIKU__)
unsigned int count = std::thread::hardware_concurrency();
if (count > 0) return static_cast<int>(count);
#endif

View File

@ -752,6 +752,12 @@ __device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value,
return __shfl_down_sync(mask, value, delta, width);
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown(
unsigned mask, Eigen::half value, int delta, int width = warpSize) {
return Eigen::half(
__shfl_down_sync(mask, static_cast<uint16>(value), delta, width));
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
@ -774,6 +780,12 @@ __device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value,
return __shfl_xor_sync(mask, value, laneMask, width);
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor(
unsigned mask, Eigen::half value, int laneMask, int width = warpSize) {
return Eigen::half(
__shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width));
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.

View File

@ -24,25 +24,25 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "mkl_service.h"
#include "mkl_trans.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
using mkldnn::engine;
using mkldnn::memory;
using mkldnn::padding_kind;
using mkldnn::primitive;
using mkldnn::reorder;
#endif
// The file contains a number of utility classes and functions used by MKL
@ -56,8 +56,14 @@ namespace tensorflow {
// Tensorflow tensor.
typedef enum { W = 0, H = 1, C = 2, N = 3 } MklDims;
typedef enum {
Dim_N = 0,
Dim_C = 1,
Dim_H = 2,
Dim_W = 3,
Dim_O = 0,
Dim_I = 1
} MklDnnDims;
class MklShape {
public:
@ -236,8 +242,7 @@ class MklShape {
(IS_MKL_TENSOR_OFFSET + sizeof(size_t))  // Location of dimension_
// Location of sizes. Note dim is not used here, left here
// to make macros consistent.
#define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
#define STRIDES_OFFSET(dims) \
(SIZES_OFFSET(dims) + dims * sizeof(size_t))  // Location of strides
#define MKL_LAYOUT_OFFSET(dims) \
@ -345,15 +350,13 @@ class MklDnnShape {
typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
#define INVALID_DIM_SIZE -1
public:
MklDnnShape() {
for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
++i) {
data_.sizes_[i] = -1;
}
for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
data_.map_[i] = -1;
}
}
@ -497,9 +500,7 @@ class MklDnnShape {
SetTfDimOrder(dimension, data_format);
}
inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
inline int64 TfDimSize(int index) const {
return data_.sizes_[TfDimIdx(index)];
@ -553,9 +554,7 @@ class MklDnnShape {
/// Size of buffer to hold the serialized object, the size is computed by
/// following above mentioned order
inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
CHECK(buf_size >= GetSerializeBufferSize())
@ -660,8 +659,7 @@ inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
}
#ifdef INTEL_MKL_DNN
inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
mklshape->DeSerializeMklDnnShape(
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
.flat<uint8>()
@ -700,8 +698,7 @@ inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
/// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
/// If the input tensor is in MKL layout, then obtains TensorShape from
/// MklShape.
inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
// Sanity check.
CHECK_NOTNULL(context);
CHECK_LT(input_idx, context->num_inputs());
@ -1099,7 +1096,8 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
///
/// @input None
/// @return memory::data_type corresponding to type T
template <typename T>
static memory::data_type MklDnnType();
/// Instantiation for float type. Add similar instantiations for other
/// type if needed.
@ -1114,10 +1112,11 @@ memory::data_type MklDnnType<float>() {
/// @return: memory::format corresponding to TensorFlow data format;
/// Fails with an error if invalid data format.
inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
if (format == FORMAT_NHWC)
return memory::format::nhwc;
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
// Return to get rid of compiler warning
return memory::format::format_undef;
}
@ -1128,10 +1127,11 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return: Tensorflow data format corresponding to memory::format
/// Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
if (format == memory::format::nhwc)
return FORMAT_NHWC;
else if (format == memory::format::nchw)
return FORMAT_NCHW;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
}
/// Map TensorShape object into memory::dims required by MKL-DNN
@ -1237,9 +1237,11 @@ class MklDnnData {
const engine* cpu_engine_;
public:
explicit MklDnnData(const engine* e)
: user_memory_(nullptr),
reorder_memory_(nullptr),
op_md_(nullptr),
cpu_engine_(e) {}
~MklDnnData() {
cpu_engine_ = nullptr;  // We don't own this.
@ -1250,8 +1252,8 @@ class MklDnnData {
inline void* GetTensorBuffer(const Tensor* tensor) const {
CHECK_NOTNULL(tensor);
return const_cast<void*>(
static_cast<const void*>(tensor->flat<T>().data()));
}
/// Set user memory primitive using specified dimensions, memory format and

Some files were not shown because too many files have changed in this diff.