Merge pull request #14814 from yifeif/branch_176709725
Branch 176709725
This commit is contained in:
commit
ab0fcaceda
22
configure.py
22
configure.py
@ -905,6 +905,28 @@ def set_trisycl_include_dir(environ_cp):
|
||||
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
|
||||
trisycl_include_dir)
|
||||
|
||||
def set_trisycl_include_dir(environ_cp):
|
||||
"""Set TRISYCL_INCLUDE_DIR."""
|
||||
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
|
||||
'include directory. (Use --config=sycl_trisycl '
|
||||
'when building with Bazel) '
|
||||
'[Default is %s]: ') % (
|
||||
_DEFAULT_TRISYCL_INCLUDE_DIR)
|
||||
while True:
|
||||
trisycl_include_dir = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
|
||||
_DEFAULT_TRISYCL_INCLUDE_DIR)
|
||||
if os.path.exists(trisycl_include_dir):
|
||||
break
|
||||
|
||||
print('Invalid triSYCL include directory, %s cannot be found' %
|
||||
(trisycl_include_dir))
|
||||
|
||||
# Set TRISYCL_INCLUDE_DIR
|
||||
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
|
||||
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
|
||||
|
||||
|
||||
def set_mpi_home(environ_cp):
|
||||
"""Set MPI_HOME."""
|
||||
default_mpi_home = which('mpirun') or which('mpiexec') or ''
|
||||
|
@ -189,7 +189,7 @@ def tf_library(name, graph, config,
|
||||
" --cpp_class=" + cpp_class +
|
||||
" --target_triple=" + target_llvm_triple() +
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
flags),
|
||||
" " + flags),
|
||||
tools=[tfcompile_tool],
|
||||
visibility=visibility,
|
||||
testonly=testonly,
|
||||
|
@ -76,7 +76,8 @@ class FusedBatchNormTest(XLATestCase):
|
||||
# To avoid constant folding
|
||||
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
|
||||
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
|
||||
offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
|
||||
offset = array_ops.placeholder(
|
||||
np.float32, shape=scale_shape, name="offset")
|
||||
epsilon = 0.001
|
||||
y_ref, mean_ref, var_ref = self._reference_training(
|
||||
x_val, scale_val, offset_val, epsilon, data_format)
|
||||
@ -112,7 +113,8 @@ class FusedBatchNormTest(XLATestCase):
|
||||
# To avoid constant folding
|
||||
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
|
||||
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
|
||||
offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
|
||||
offset = array_ops.placeholder(
|
||||
np.float32, shape=scale_shape, name="offset")
|
||||
epsilon = 0.001
|
||||
y, mean, var = nn.fused_batch_norm(
|
||||
t_val,
|
||||
|
@ -67,6 +67,15 @@ class Client {
|
||||
std::vector<GlobalData*> arguments;
|
||||
ExecutionOptions execution_options;
|
||||
ExecutionProfile* execution_profile;
|
||||
|
||||
ComputationInstance(const Computation& computation,
|
||||
std::vector<GlobalData*> arguments,
|
||||
ExecutionOptions execution_options,
|
||||
ExecutionProfile* execution_profile)
|
||||
: computation(computation),
|
||||
arguments(std::move(arguments)),
|
||||
execution_options(execution_options),
|
||||
execution_profile(execution_profile) {}
|
||||
};
|
||||
|
||||
// Executes a list ComputationInstances and returns global data produced from
|
||||
@ -133,7 +142,7 @@ class Client {
|
||||
|
||||
// Returns a vector of global data handles that point to the tuple elements.
|
||||
StatusOr<std::vector<std::unique_ptr<GlobalData>>> DeconstructTuple(
|
||||
const GlobalData& computation);
|
||||
const GlobalData& data);
|
||||
|
||||
// Retrieves the statistics of the given computation.
|
||||
StatusOr<ComputationStats> GetComputationStats(
|
||||
|
@ -85,9 +85,9 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
|
||||
HloOpcode opcode) {
|
||||
HloComputation::Builder b("scalar_computation");
|
||||
auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs"));
|
||||
0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
|
||||
auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs"));
|
||||
1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
|
||||
auto scalar_op = b.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
|
||||
opcode, scalar_lhs, scalar_rhs));
|
||||
@ -152,22 +152,30 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
|
||||
// Expand batch norm training into smaller HLO ops.
|
||||
HloInstruction* operand = batch_norm->mutable_operand(0);
|
||||
const Shape operand_shape = operand->shape();
|
||||
PrimitiveType ptype = operand_shape.element_type();
|
||||
int64 feature_index = batch_norm->feature_index();
|
||||
const int64 feature_count = operand_shape.dimensions(feature_index);
|
||||
const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
|
||||
auto elements_per_feature =
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR0<float>(size_in_elements / feature_count)));
|
||||
auto elements_per_feature_literal =
|
||||
Literal::CreateR0<float>(size_in_elements / feature_count);
|
||||
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
|
||||
elements_per_feature_literal->Convert(ptype));
|
||||
auto elements_per_feature = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
|
||||
|
||||
HloInstruction* scale = batch_norm->mutable_operand(1);
|
||||
HloInstruction* offset = batch_norm->mutable_operand(2);
|
||||
const Shape feature_shape = scale->shape();
|
||||
|
||||
auto zero_literal = Literal::CreateR0(0.0f);
|
||||
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
|
||||
HloInstruction::CreateConstant(std::move(zero_literal)));
|
||||
|
||||
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
|
||||
auto epsilon = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
|
||||
HloInstruction::CreateConstant(std::move(epsilon_literal)));
|
||||
|
||||
std::vector<int64> dimensions_without_feature;
|
||||
|
||||
@ -184,7 +192,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
|
||||
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
|
||||
|
||||
HloComputation* add_reduce_computation =
|
||||
GetScalarBinaryComputation(F32, HloOpcode::kAdd);
|
||||
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
|
||||
|
||||
// X^2.
|
||||
auto operand_squared =
|
||||
@ -243,8 +251,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
|
||||
|
||||
auto neg_half_literal = Literal::CreateR0(-0.5f);
|
||||
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
|
||||
auto neg_half = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
|
||||
HloInstruction::CreateConstant(std::move(neg_half_literal)));
|
||||
|
||||
// 1 / Sqrt[Var[X] + epsilon].
|
||||
auto rsqrt_var_add_epsilon =
|
||||
@ -286,6 +296,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
|
||||
HloInstruction* operand = batch_norm->mutable_operand(0);
|
||||
const Shape operand_shape = operand->shape();
|
||||
int64 feature_index = batch_norm->feature_index();
|
||||
PrimitiveType ptype = operand_shape.element_type();
|
||||
|
||||
HloInstruction* scale = batch_norm->mutable_operand(1);
|
||||
HloInstruction* offset = batch_norm->mutable_operand(2);
|
||||
@ -293,8 +304,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
|
||||
HloInstruction* var = batch_norm->mutable_operand(4);
|
||||
const Shape feature_shape = scale->shape();
|
||||
|
||||
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
|
||||
auto epsilon = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
|
||||
HloInstruction::CreateConstant(std::move(epsilon_literal)));
|
||||
|
||||
std::vector<int64> dimensions_without_feature;
|
||||
|
||||
@ -321,8 +334,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
|
||||
|
||||
auto neg_half_literal = Literal::CreateR0(-0.5f);
|
||||
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
|
||||
auto neg_half = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
|
||||
HloInstruction::CreateConstant(std::move(neg_half_literal)));
|
||||
|
||||
// 1 / Sqrt[Var[X] + epsilon].
|
||||
auto rsqrt_var_add_epsilon =
|
||||
@ -373,6 +388,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
|
||||
|
||||
HloInstruction* activation = batch_norm->mutable_operand(0);
|
||||
const Shape activation_shape = activation->shape();
|
||||
PrimitiveType ptype = activation_shape.element_type();
|
||||
HloInstruction* scale = batch_norm->mutable_operand(1);
|
||||
const Shape feature_shape = scale->shape();
|
||||
HloInstruction* mean = batch_norm->mutable_operand(2);
|
||||
@ -383,18 +399,27 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
|
||||
|
||||
const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
|
||||
const int64 feature_count = activation_shape.dimensions(feature_index);
|
||||
auto elements_per_feature =
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR0<float>(size_in_elements / feature_count)));
|
||||
auto elements_per_feature_literal =
|
||||
Literal::CreateR0<float>(size_in_elements / feature_count);
|
||||
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
|
||||
elements_per_feature_literal->Convert(ptype));
|
||||
auto elements_per_feature = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
|
||||
|
||||
auto zero_literal = Literal::CreateR0(0.0f);
|
||||
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
|
||||
HloInstruction::CreateConstant(std::move(zero_literal)));
|
||||
|
||||
auto neg_half_literal = Literal::CreateR0(-0.5f);
|
||||
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
|
||||
auto neg_half = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
|
||||
HloInstruction::CreateConstant(std::move(neg_half_literal)));
|
||||
|
||||
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
|
||||
auto epsilon = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
|
||||
HloInstruction::CreateConstant(std::move(epsilon_literal)));
|
||||
|
||||
std::vector<int64> dimensions_without_feature;
|
||||
|
||||
@ -442,7 +467,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
|
||||
grad_output, activation_minus_mean));
|
||||
|
||||
HloComputation* add_reduce_computation =
|
||||
GetScalarBinaryComputation(F32, HloOpcode::kAdd);
|
||||
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
|
||||
|
||||
// sum(Grad[Y] * (X - E[X])).
|
||||
auto sum_grad_output_times_activiation_minus_mean =
|
||||
|
@ -197,28 +197,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
|
||||
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
static StatusOr<std::unordered_map<const HloInstruction*, size_t>>
|
||||
GetCandidatesForComputation(HloComputation* computation) {
|
||||
GetCandidatesForComputation(
|
||||
HloComputation* computation,
|
||||
const std::unordered_map<const HloInstruction*, int64>&
|
||||
assigned_indices) {
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
|
||||
CollectProfileCandidates profile_candidates_for_computation(
|
||||
&hlo_to_profile_idx);
|
||||
&hlo_to_profile_idx, assigned_indices);
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation->Accept(&profile_candidates_for_computation));
|
||||
return hlo_to_profile_idx;
|
||||
}
|
||||
|
||||
private:
|
||||
explicit CollectProfileCandidates(
|
||||
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
|
||||
: hlo_to_profile_idx_(hlo_to_profile_idx) {}
|
||||
CollectProfileCandidates(
|
||||
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx,
|
||||
const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
|
||||
: hlo_to_profile_idx_(hlo_to_profile_idx),
|
||||
assigned_indices_(assigned_indices) {}
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo_instruction) override {
|
||||
hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()});
|
||||
hlo_to_profile_idx_->insert(
|
||||
{hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleCall(HloInstruction* call) override {
|
||||
TF_RETURN_IF_ERROR(DefaultAction(call));
|
||||
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
|
||||
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
|
||||
assigned_indices_);
|
||||
TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -232,17 +239,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
|
||||
Status HandleWhile(HloInstruction* xla_while) override {
|
||||
TF_RETURN_IF_ERROR(DefaultAction(xla_while));
|
||||
|
||||
CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_);
|
||||
CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
|
||||
assigned_indices_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
xla_while->while_condition()->Accept(&candidates_for_condition));
|
||||
|
||||
CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_);
|
||||
CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
|
||||
assigned_indices_);
|
||||
TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
|
||||
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -475,10 +485,27 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
|
||||
HloComputation* computation = module->entry_computation();
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer;
|
||||
if (module->config().hlo_profiling_enabled()) {
|
||||
hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
hlo_to_profile_idx,
|
||||
CollectProfileCandidates::GetCandidatesForComputation(computation));
|
||||
CollectProfileCandidates::GetCandidatesForComputation(
|
||||
computation, hlo_profile_index_map->instruction_to_profile_idx()));
|
||||
|
||||
auto shape_size_bytes = [](const Shape& shape) {
|
||||
// On the cpu, opaques are pointers.
|
||||
if (ShapeUtil::IsOpaque(shape)) {
|
||||
return static_cast<int64>(sizeof(void*));
|
||||
}
|
||||
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
|
||||
};
|
||||
|
||||
HloCostAnalysis cost_analysis(shape_size_bytes);
|
||||
hlo_profile_printer =
|
||||
CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis);
|
||||
}
|
||||
|
||||
std::unique_ptr<Executable> cpu_executable;
|
||||
@ -544,8 +571,16 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
parallel_computations.emplace(to_apply, instruction);
|
||||
}
|
||||
|
||||
// We always profile the entire computation as a whole, even if hlo
|
||||
// profiling is disabled. When hlo profiling is diabled, we pass in a
|
||||
// profile counter array of just one element, which corresponds to the whole
|
||||
// computation.
|
||||
size_t entry_computation_profile_idx =
|
||||
hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
|
||||
*module->entry_computation())
|
||||
: 0;
|
||||
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
|
||||
hlo_to_profile_idx, hlo_to_profile_idx.size(),
|
||||
hlo_to_profile_idx, entry_computation_profile_idx,
|
||||
jit->target_machine(), jit->external_constant_pool());
|
||||
|
||||
std::unique_ptr<HloInstructionMap<string>> function_names(
|
||||
@ -586,8 +621,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
jit->AddModule(std::move(llvm_module));
|
||||
cpu_executable.reset(new ParallelCpuExecutable(
|
||||
std::move(jit), std::move(assignment), std::move(module),
|
||||
std::move(function_names), std::move(hlo_to_profile_idx),
|
||||
std::move(aligned_constants)));
|
||||
std::move(function_names), std::move(aligned_constants),
|
||||
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
|
||||
|
||||
if (embed_ir_in_executable) {
|
||||
static_cast<CpuExecutable&>(*cpu_executable)
|
||||
@ -620,12 +655,22 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
|
||||
proto, xla_dump_hlo_proto_to, module->name()));
|
||||
}
|
||||
// We always profile the entire computation as a whole, even if hlo
|
||||
// profiling is disabled. When hlo profiling is diabled, we pass in a
|
||||
// profile counter array of just one element, which corresponds to the whole
|
||||
// computation.
|
||||
size_t entry_computation_profile_idx =
|
||||
hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
|
||||
*module->entry_computation())
|
||||
: 0;
|
||||
|
||||
// Each computation is a single function. Emit all embedded computations
|
||||
// before the entry computation. The order of computations returned from
|
||||
// GetEmbeddedComputations guarantees that a called computation occurs
|
||||
// before a caller computation.
|
||||
|
||||
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
|
||||
hlo_to_profile_idx, hlo_to_profile_idx.size(),
|
||||
hlo_to_profile_idx, entry_computation_profile_idx,
|
||||
jit->target_machine(), jit->external_constant_pool());
|
||||
|
||||
for (auto embedded_computation :
|
||||
@ -659,7 +704,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
jit->AddModule(std::move(llvm_module));
|
||||
cpu_executable.reset(new CpuExecutable(
|
||||
std::move(jit), std::move(assignment), std::move(module), function_name,
|
||||
std::move(hlo_to_profile_idx)));
|
||||
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
|
||||
|
||||
if (embed_ir_in_executable) {
|
||||
static_cast<CpuExecutable&>(*cpu_executable)
|
||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/host/host_stream.h"
|
||||
|
||||
namespace se = ::perftools::gputools;
|
||||
|
||||
@ -54,11 +55,12 @@ CpuExecutable::CpuExecutable(
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
const string& entry_function_name,
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx)
|
||||
: Executable(std::move(hlo_module)),
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
|
||||
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
|
||||
std::move(hlo_profile_index_map)),
|
||||
jit_(std::move(jit)),
|
||||
assignment_(std::move(assignment)),
|
||||
hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) {
|
||||
assignment_(std::move(assignment)) {
|
||||
// Resolve symbols in the constructor rather than at execution time to avoid
|
||||
// races because FindSymbol is not thread safe.
|
||||
llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name);
|
||||
@ -182,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
|
||||
|
||||
// Allocate profiling counters for each hlo instruction that we would like to
|
||||
// profile. Allocate an additional profile counter for the entire
|
||||
// computation.
|
||||
std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
|
||||
// profile. Even when not Hlo profiling, we allocate a counter for the entire
|
||||
// computation, which we use to update ExecutionProfile below.
|
||||
std::vector<int64>* profile_counters = nullptr;
|
||||
std::vector<int64> profile_counter_for_entry_computation;
|
||||
if (hlo_execution_profile) {
|
||||
profile_counters = hlo_execution_profile->mutable_profile_counters();
|
||||
} else {
|
||||
profile_counters = &profile_counter_for_entry_computation;
|
||||
profile_counter_for_entry_computation.push_back(0);
|
||||
}
|
||||
|
||||
// Call the computation function following the calling convention.
|
||||
std::vector<void*> buffer_pointers;
|
||||
@ -199,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
VLOG(3) << tensorflow::strings::Printf(
|
||||
" func(void* result, void* params[%zu], void* temps[%zu], "
|
||||
"uint64 profile_counters[%zu])",
|
||||
args_array.size(), buffer_pointers.size(), profile_counters.size());
|
||||
args_array.size(), buffer_pointers.size(), profile_counters->size());
|
||||
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
|
||||
auto ptr_printer = [](string* out, const void* p) {
|
||||
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
|
||||
@ -211,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
" temps = [%s]",
|
||||
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
|
||||
VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
|
||||
profile_counters.data());
|
||||
profile_counters->data());
|
||||
}
|
||||
|
||||
compute_function_(result_buffer, run_options, args_array.data(),
|
||||
buffer_pointers.data(), profile_counters.data());
|
||||
buffer_pointers.data(), profile_counters->data());
|
||||
|
||||
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
|
||||
|
||||
@ -224,20 +233,46 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
const double nanoseconds = (end_micros - start_micros) * 1000.0;
|
||||
execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
|
||||
|
||||
// The last profile counter is used for the computation as a whole.
|
||||
execution_profile_.set_compute_cycle_count(profile_counters.back());
|
||||
}
|
||||
|
||||
if (hlo_execution_profile != nullptr) {
|
||||
hlo_execution_profile->set_total_cycles_executed(
|
||||
*module().entry_computation(), profile_counters.back());
|
||||
|
||||
for (auto hlo_prof_idx : hlo_to_profile_idx_) {
|
||||
const HloInstruction* hlo = hlo_prof_idx.first;
|
||||
uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
|
||||
hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
|
||||
if (hlo_execution_profile) {
|
||||
execution_profile_.set_compute_cycle_count(
|
||||
hlo_execution_profile->total_cycles_executed(
|
||||
*module().entry_computation()));
|
||||
} else {
|
||||
execution_profile_.set_compute_cycle_count(profile_counters->back());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void LogLiveAddresses(
|
||||
const std::unordered_set<const void*>& marked_addresses) {
|
||||
VLOG(3) << "Live addresses in output marking found "
|
||||
<< marked_addresses.size() << " addresses:\n"
|
||||
<< tensorflow::str_util::Join(
|
||||
marked_addresses, ", ", [](string* out, const void* address) {
|
||||
tensorflow::strings::StrAppend(
|
||||
out, tensorflow::strings::Printf("%p", address));
|
||||
});
|
||||
}
|
||||
|
||||
static Status DeallocateTempBuffers(
|
||||
DeviceMemoryAllocator* allocator, se::Stream* stream,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
|
||||
const std::unordered_set<const void*>& marked_addresses) {
|
||||
// Keep those marked live because they are referenced by the output of the
|
||||
// computation and are needed by the service. They will be deallocated by the
|
||||
// service.
|
||||
for (size_t i = 0; i < buffers.size(); ++i) {
|
||||
se::DeviceMemoryBase alloc = buffers[i];
|
||||
if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) {
|
||||
VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
|
||||
<< alloc.opaque() << "]";
|
||||
TF_RETURN_IF_ERROR(
|
||||
allocator->Deallocate(stream->parent()->device_ordinal(), &alloc));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -263,26 +298,9 @@ StatusOr<perftools::gputools::DeviceMemoryBase> CpuExecutable::ExecuteOnStream(
|
||||
MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
|
||||
&marked_addresses);
|
||||
|
||||
VLOG(3) << "Live addresses in output marking found "
|
||||
<< marked_addresses.size() << " addresses:\n"
|
||||
<< tensorflow::str_util::Join(
|
||||
marked_addresses, ", ", [](string* out, const void* address) {
|
||||
tensorflow::strings::StrAppend(
|
||||
out, tensorflow::strings::Printf("%p", address));
|
||||
});
|
||||
|
||||
// Computation is done - deallocate temp buffers. Keep those marked live
|
||||
// because they are referenced by the output of the computation and are needed
|
||||
// by the service. They will be deallocated by the service.
|
||||
for (size_t i = 0; i < buffers.size(); ++i) {
|
||||
se::DeviceMemoryBase alloc = buffers[i];
|
||||
if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) {
|
||||
VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
|
||||
<< alloc.opaque() << "]";
|
||||
TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
|
||||
stream->parent()->device_ordinal(), &alloc));
|
||||
}
|
||||
}
|
||||
LogLiveAddresses(marked_addresses);
|
||||
TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers,
|
||||
marked_addresses));
|
||||
|
||||
return top_level_output;
|
||||
}
|
||||
@ -360,9 +378,44 @@ StatusOr<perftools::gputools::DeviceMemoryBase>
|
||||
CpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
|
||||
// TODO(b/30671675): Implement asynchronous execution mode.
|
||||
return Unimplemented(
|
||||
"Asynchronous execution on stream is not yet supported on CPU.");
|
||||
if (hlo_profiling_enabled()) {
|
||||
return Unimplemented(
|
||||
"Asynchronous execution on stream with hlo profiling is not yet "
|
||||
"supported on CPU.");
|
||||
}
|
||||
|
||||
auto* host_stream = dynamic_cast<perftools::gputools::host::HostStream*>(
|
||||
run_options->stream()->implementation());
|
||||
se::Stream* stream = run_options->stream();
|
||||
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
|
||||
|
||||
TF_RETURN_IF_ERROR(AllocateBuffers(
|
||||
memory_allocator, stream->parent()->device_ordinal(), &buffers));
|
||||
|
||||
// Mark the buffers that are actually live (used in the output) when the
|
||||
// computation finishes executing.
|
||||
std::unordered_set<const void*> marked_addresses;
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
|
||||
assignment_->GetUniqueTopLevelOutputSlice());
|
||||
se::DeviceMemoryBase top_level_output = buffers[result_slice.index()];
|
||||
MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
|
||||
&marked_addresses);
|
||||
|
||||
LogLiveAddresses(marked_addresses);
|
||||
|
||||
host_stream->EnqueueTask([this, run_options, arguments, buffers,
|
||||
marked_addresses, memory_allocator, stream]() {
|
||||
// Failing a CHECK here is not great, but I don't see an obvious way to
|
||||
// return a failed Status asynchronously.
|
||||
TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments,
|
||||
buffers,
|
||||
/*hlo_execution_profile=*/nullptr));
|
||||
TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers,
|
||||
marked_addresses));
|
||||
});
|
||||
|
||||
return top_level_output;
|
||||
}
|
||||
|
||||
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||
@ -378,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const {
|
||||
module().entry_computation()->root_instruction());
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> CpuExecutable::CreateCostAnalysis() const {
|
||||
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -47,12 +47,12 @@ namespace cpu {
|
||||
// architecture, so JIT-ed code and host code share the same ABI.
|
||||
class CpuExecutable : public Executable {
|
||||
public:
|
||||
CpuExecutable(
|
||||
std::unique_ptr<SimpleOrcJIT> jit,
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
const string& entry_function_name,
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx);
|
||||
CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
const string& entry_function_name,
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
||||
~CpuExecutable() override {}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
|
||||
@ -85,12 +85,10 @@ class CpuExecutable : public Executable {
|
||||
|
||||
static int64 ShapeSizeBytes(const Shape& shape);
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
|
||||
|
||||
// Type of the computation function we expect in the JIT.
|
||||
using ComputeFunctionType = void (*)(
|
||||
void* /*result*/, const ExecutableRunOptions* /*run_options*/,
|
||||
const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/);
|
||||
const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/);
|
||||
|
||||
const ComputeFunctionType& compute_function() const {
|
||||
return compute_function_;
|
||||
@ -145,9 +143,6 @@ class CpuExecutable : public Executable {
|
||||
// Entry function name for the computation.
|
||||
const string entry_function_name_;
|
||||
|
||||
// Maps HLOs to their index into the profile counter array.
|
||||
const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
|
||||
};
|
||||
|
||||
|
@ -59,19 +59,20 @@ ParallelCpuExecutable::ParallelCpuExecutable(
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
std::unique_ptr<const HloInstructionMap<string>> function_names,
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
|
||||
std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
|
||||
aligned_constants)
|
||||
: Executable(std::move(hlo_module)),
|
||||
aligned_constants,
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
|
||||
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
|
||||
std::move(hlo_profile_index_map)),
|
||||
jit_(std::move(jit)),
|
||||
assignment_(std::move(assignment)),
|
||||
function_names_(std::move(function_names)),
|
||||
hlo_to_profile_idx_(std::move(hlo_to_profile_idx)),
|
||||
aligned_constants_(std::move(aligned_constants)) {}
|
||||
|
||||
// Type of the computation function we expect in the JIT.
|
||||
using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
|
||||
int64*, uint64*);
|
||||
int64*, int64*);
|
||||
|
||||
// Given a pointer to an output buffer (following the CPU JIT calling
|
||||
// conventions), mark addresses that are "live". The initial pointer itself is
|
||||
@ -106,7 +107,7 @@ class Executor {
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::list<HloInstruction*>* pending,
|
||||
HloInstructionMap<const void*>* results, void** temps_array,
|
||||
uint64* profile_counters_array, const BufferAssignment* assignment)
|
||||
int64* profile_counters_array, const BufferAssignment* assignment)
|
||||
: functions_(functions),
|
||||
run_options_(run_options),
|
||||
pending_(pending),
|
||||
@ -147,7 +148,7 @@ class Executor {
|
||||
std::list<HloInstruction*>* pending_;
|
||||
HloInstructionMap<const void*>* results_;
|
||||
void** temps_array_;
|
||||
uint64* profile_counters_array_;
|
||||
int64* profile_counters_array_;
|
||||
tensorflow::thread::ThreadPool* thread_pool_;
|
||||
const BufferAssignment* assignment_;
|
||||
|
||||
@ -389,9 +390,11 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
// Allocate profiling counters for each hlo instruction that we would like to
|
||||
// profile. Allocate an additional profile counter for the entire
|
||||
// computation.
|
||||
std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
|
||||
// profile.
|
||||
std::vector<int64>* profile_counters = nullptr;
|
||||
if (hlo_execution_profile) {
|
||||
profile_counters = hlo_execution_profile->mutable_profile_counters();
|
||||
}
|
||||
|
||||
std::vector<void*> buffer_pointers;
|
||||
buffer_pointers.reserve(buffers.size());
|
||||
@ -441,9 +444,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
|
||||
// For example, if we expect a library conv/matmul call to run at max
|
||||
// concurrency, we should not dispatch runnable instructions until the
|
||||
// library call is finished (to avoid expensive cache invalidation).
|
||||
Executor executor(functions, run_options, &pending, &results,
|
||||
buffer_pointers.data(), profile_counters.data(),
|
||||
assignment_.get());
|
||||
Executor executor(
|
||||
functions, run_options, &pending, &results, buffer_pointers.data(),
|
||||
profile_counters ? profile_counters->data() : nullptr, assignment_.get());
|
||||
|
||||
TF_RETURN_IF_ERROR(executor.Run());
|
||||
|
||||
@ -453,18 +456,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
double nanoseconds = (end_micros - start_micros) * 1000.0;
|
||||
execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
|
||||
// The last profile counter is used for the computation as a whole.
|
||||
execution_profile_.set_compute_cycle_count(profile_counters.back());
|
||||
}
|
||||
if (hlo_execution_profile != nullptr) {
|
||||
hlo_execution_profile->set_total_cycles_executed(entry_computation,
|
||||
profile_counters.back());
|
||||
|
||||
for (auto hlo_prof_idx : hlo_to_profile_idx_) {
|
||||
const HloInstruction* hlo = hlo_prof_idx.first;
|
||||
uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
|
||||
hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -618,10 +609,5 @@ const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
|
||||
module().entry_computation()->root_instruction());
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> ParallelCpuExecutable::CreateCostAnalysis()
|
||||
const {
|
||||
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -52,10 +52,11 @@ class ParallelCpuExecutable : public Executable {
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
std::unique_ptr<const HloInstructionMap<string>> function_names,
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
|
||||
std::unordered_map<const HloInstruction*,
|
||||
std::unique_ptr<unsigned char[]>>
|
||||
aligned_constants);
|
||||
aligned_constants,
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
||||
~ParallelCpuExecutable() override {}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
|
||||
@ -95,8 +96,6 @@ class ParallelCpuExecutable : public Executable {
|
||||
"Equality test on CPU parallel executable is not implemented.");
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
|
||||
|
||||
private:
|
||||
// Allocate buffers required for execution and assign them to the elements of
|
||||
// "buffers". "buffers" should be sized to the number of buffers in buffer
|
||||
@ -143,9 +142,6 @@ class ParallelCpuExecutable : public Executable {
|
||||
// Map containing the JITted function names for each HLO instruction.
|
||||
const std::unique_ptr<const HloInstructionMap<string>> function_names_;
|
||||
|
||||
// Maps HLOs to their index into the profile counter array.
|
||||
const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
|
||||
|
||||
// Map from HLO Constant instructions to a pointer to their literal data.
|
||||
// The data stored in the protocol buffer might be insufficiently aligned,
|
||||
// we create a sufficiently aligned copy and store it in this map.
|
||||
|
@ -44,8 +44,15 @@ namespace xla {
|
||||
// interface that is used for launching compiled programs across platforms.
|
||||
class Executable {
|
||||
public:
|
||||
explicit Executable(std::unique_ptr<const HloModule> hlo_module)
|
||||
: hlo_module_(std::move(hlo_module)) {}
|
||||
explicit Executable(std::unique_ptr<const HloModule> hlo_module,
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
|
||||
: hlo_module_(std::move(hlo_module)),
|
||||
hlo_profile_printer_(std::move(hlo_profile_printer)),
|
||||
hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
|
||||
CHECK_EQ(hlo_profile_printer_.get() == nullptr,
|
||||
hlo_profile_index_map_.get() == nullptr);
|
||||
}
|
||||
virtual ~Executable() {}
|
||||
|
||||
// Enqueues the compilation result on the provided stream, passing the given
|
||||
@ -123,12 +130,20 @@ class Executable {
|
||||
"Equality test on this executable is not implemented.");
|
||||
}
|
||||
|
||||
const HloProfilePrinter& hlo_profile_printer() const {
|
||||
CHECK(hlo_profiling_enabled());
|
||||
return *hlo_profile_printer_;
|
||||
}
|
||||
|
||||
const HloProfileIndexMap& hlo_profile_index_map() const {
|
||||
CHECK(hlo_profiling_enabled());
|
||||
return *hlo_profile_index_map_;
|
||||
}
|
||||
|
||||
// Returns whether this executable was compiled with HLO profilings support
|
||||
// enabled. If not, the caller should not expect an hlo_execution_profile
|
||||
// passed to ExecuteOnStream above to be populated during execution.
|
||||
bool hlo_profiling_enabled() const {
|
||||
return hlo_module_->config().hlo_profiling_enabled();
|
||||
}
|
||||
bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
|
||||
|
||||
const HloModule& module() const { return *hlo_module_; }
|
||||
|
||||
@ -160,10 +175,6 @@ class Executable {
|
||||
static Status DumpToDirectory(const string& directory_path, string filename,
|
||||
const SessionModule& session_module);
|
||||
|
||||
// Returns a cost analysis object appropriate for the platform on which this
|
||||
// executable can run.
|
||||
virtual std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const = 0;
|
||||
|
||||
protected:
|
||||
mutable tensorflow::mutex mutex_;
|
||||
|
||||
@ -181,6 +192,9 @@ class Executable {
|
||||
// Execution count, used to generate a unique filename for each dumped
|
||||
// execution.
|
||||
int64 execution_count_ = 0;
|
||||
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer_;
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
|
||||
};
|
||||
|
||||
template <typename ReturnT, typename ArgT>
|
||||
@ -200,7 +214,8 @@ StatusOr<ReturnT> Executable::ExecuteOnStreamWrapper(
|
||||
std::unique_ptr<HloExecutionProfile> profile_ptr =
|
||||
module_config().debug_options().xla_hlo_profile() &&
|
||||
hlo_profiling_enabled()
|
||||
? MakeUnique<HloExecutionProfile>(module(), *CreateCostAnalysis())
|
||||
? MakeUnique<HloExecutionProfile>(&hlo_profile_printer(),
|
||||
&hlo_profile_index_map())
|
||||
: nullptr;
|
||||
|
||||
auto return_value =
|
||||
|
@ -465,10 +465,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
VLOG(2) << "Printing the thunk schedule...";
|
||||
XLA_VLOG_LINES(2, thunk_schedule->ToString());
|
||||
|
||||
auto* gpu_executable =
|
||||
new GpuExecutable(ptx, cubin, {cc_major, cc_minor},
|
||||
std::move(thunk_schedule), std::move(module),
|
||||
std::move(buffer_assignment), ShapeSizeBytesFunction());
|
||||
std::unique_ptr<HloProfileIndexMap> profile_index_map;
|
||||
std::unique_ptr<HloProfilePrinter> profile_printer;
|
||||
|
||||
if (module->config().hlo_profiling_enabled()) {
|
||||
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
|
||||
profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
|
||||
profile_printer =
|
||||
CreateHloProfilePrinter(*profile_index_map, cost_analysis);
|
||||
}
|
||||
|
||||
auto* gpu_executable = new GpuExecutable(
|
||||
ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
|
||||
std::move(module), std::move(buffer_assignment),
|
||||
std::move(profile_printer), std::move(profile_index_map));
|
||||
if (embed_ir_in_executable) {
|
||||
DCHECK_NE("", ir_module_string_before_opt);
|
||||
gpu_executable->set_ir_module_string(ir_module_string_before_opt);
|
||||
|
@ -113,14 +113,15 @@ GpuExecutable::GpuExecutable(
|
||||
std::unique_ptr<const ThunkSchedule> thunk_schedule,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
HloCostAnalysis::ShapeSizeFunction shape_size_function)
|
||||
: Executable(std::move(hlo_module)),
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
|
||||
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
|
||||
std::move(hlo_profile_index_map)),
|
||||
ptx_(ptx),
|
||||
cubin_(cubin),
|
||||
compute_capability_(compute_capability),
|
||||
thunk_schedule_(std::move(thunk_schedule)),
|
||||
assignment_(std::move(assignment)),
|
||||
shape_size_function_(std::move(shape_size_function)) {}
|
||||
assignment_(std::move(assignment)) {}
|
||||
|
||||
Status GpuExecutable::ExecuteThunks(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
@ -358,9 +359,5 @@ const PointsToSet& GpuExecutable::GetRootPointsToSet() const {
|
||||
module().entry_computation()->root_instruction());
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> GpuExecutable::CreateCostAnalysis() const {
|
||||
return MakeUnique<HloCostAnalysis>(shape_size_function_);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -54,7 +54,8 @@ class GpuExecutable : public Executable {
|
||||
std::unique_ptr<const ThunkSchedule> thunk_schedule,
|
||||
std::unique_ptr<const HloModule> hlo_module,
|
||||
std::unique_ptr<const BufferAssignment> assignment,
|
||||
HloCostAnalysis::ShapeSizeFunction shape_size_function);
|
||||
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
||||
|
||||
// This should be called after set_ir_module_string.
|
||||
const string& ir_module_string() const { return ir_module_string_; }
|
||||
@ -95,8 +96,6 @@ class GpuExecutable : public Executable {
|
||||
return Unimplemented("Equality test on GPU executable is not implemented.");
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
|
||||
|
||||
private:
|
||||
// If `block_host_until_done` is false, execution will not block the host
|
||||
// until the kernels have completed. This is used as an optimization for
|
||||
@ -140,9 +139,6 @@ class GpuExecutable : public Executable {
|
||||
// memory for every output/temp buffers.
|
||||
const std::unique_ptr<const BufferAssignment> assignment_;
|
||||
|
||||
// Function to compute the size of a given Shape, in bytes.
|
||||
const HloCostAnalysis::ShapeSizeFunction shape_size_function_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
|
||||
};
|
||||
|
||||
|
@ -40,7 +40,7 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
|
||||
}
|
||||
}
|
||||
|
||||
static HloProfilePrinter CreateOwnedHloProfilePrinter(
|
||||
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
|
||||
const HloProfileIndexMap& hlo_profile_index_map,
|
||||
const HloCostAnalysis& cost_analysis) {
|
||||
using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
|
||||
@ -108,15 +108,15 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
|
||||
delete[] computation_infos;
|
||||
};
|
||||
|
||||
return HloProfilePrinter(computation_infos,
|
||||
hlo_profile_index_map.computation_count(), deleter);
|
||||
return MakeUnique<HloProfilePrinter>(
|
||||
computation_infos, hlo_profile_index_map.computation_count(), deleter);
|
||||
}
|
||||
|
||||
HloExecutionProfile::HloExecutionProfile(const HloModule& module,
|
||||
const HloCostAnalysis& cost_analysis)
|
||||
: hlo_profile_index_map_(module),
|
||||
hlo_profile_printer_(
|
||||
CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)),
|
||||
HloExecutionProfile::HloExecutionProfile(
|
||||
const HloProfilePrinter* hlo_profile_printer,
|
||||
const HloProfileIndexMap* hlo_profile_index_map)
|
||||
: hlo_profile_printer_(*hlo_profile_printer),
|
||||
hlo_profile_index_map_(*hlo_profile_index_map),
|
||||
profile_counters_(
|
||||
/*count*/ hlo_profile_index_map_.total_count(),
|
||||
/*value*/ 0) {}
|
||||
@ -131,10 +131,4 @@ uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {
|
||||
return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)];
|
||||
}
|
||||
|
||||
string HloExecutionProfile::ToString(
|
||||
const DeviceDescription& device_description) const {
|
||||
return hlo_profile_printer_.ToString(profile_counters_.data(),
|
||||
device_description.clock_rate_ghz());
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -77,6 +77,11 @@ class HloProfileIndexMap {
|
||||
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_;
|
||||
};
|
||||
|
||||
// Create an instance of `HloProfilePrinter` that owns its memory.
|
||||
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
|
||||
const HloProfileIndexMap& hlo_profile_index_map,
|
||||
const HloCostAnalysis& cost_analysis);
|
||||
|
||||
// Describes how much time each HLO operation took.
|
||||
//
|
||||
// Each HloComputation takes a certain number of cycles. This class helps break
|
||||
@ -85,8 +90,8 @@ class HloExecutionProfile {
|
||||
public:
|
||||
using DeviceDescription = perftools::gputools::DeviceDescription;
|
||||
|
||||
HloExecutionProfile(const HloModule& module,
|
||||
const HloCostAnalysis& cost_analysis);
|
||||
HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer,
|
||||
const HloProfileIndexMap* hlo_profile_index_map);
|
||||
|
||||
// Record how many cycles this HLO took to execute.
|
||||
void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken);
|
||||
@ -114,15 +119,16 @@ class HloExecutionProfile {
|
||||
// for the operations in a given computation. Returns an empty string if it
|
||||
// wasn't possible to generate a printable version. cost_analysis should be a
|
||||
// clean analysis that can be used to visit the computation.
|
||||
string ToString(const DeviceDescription& device_description) const;
|
||||
string ToString(const DeviceDescription& device_description) const {
|
||||
return hlo_profile_printer_.ToString(profile_counters_.data(),
|
||||
device_description.clock_rate_ghz());
|
||||
}
|
||||
|
||||
std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }
|
||||
|
||||
private:
|
||||
// hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to
|
||||
// an index in profile_counters_.
|
||||
HloProfileIndexMap hlo_profile_index_map_;
|
||||
|
||||
// Used to print profile_counters_ in a human readable form.
|
||||
HloProfilePrinter hlo_profile_printer_;
|
||||
const HloProfilePrinter& hlo_profile_printer_;
|
||||
const HloProfileIndexMap& hlo_profile_index_map_;
|
||||
|
||||
// Stores per-Hlo profile counters. This is the only thing that changes when
|
||||
// we execute an XLA computation.
|
||||
|
@ -72,7 +72,11 @@ TEST_F(HloExecutionProfileTest, Basic) {
|
||||
};
|
||||
|
||||
HloCostAnalysis cost_analysis(shape_size_function);
|
||||
HloExecutionProfile execution_profile(*hlo_module, cost_analysis);
|
||||
HloProfileIndexMap profile_index_map(*hlo_module);
|
||||
std::unique_ptr<HloProfilePrinter> profile_printer =
|
||||
CreateHloProfilePrinter(profile_index_map, cost_analysis);
|
||||
HloExecutionProfile execution_profile(profile_printer.get(),
|
||||
&profile_index_map);
|
||||
|
||||
const int64 add_cycles = 1000;
|
||||
const int64 dot_cycles = 4000;
|
||||
|
@ -42,7 +42,8 @@ namespace sep = ::perftools::gputools::interpreter;
|
||||
|
||||
InterpreterExecutable::InterpreterExecutable(
|
||||
std::unique_ptr<const HloModule> hlo_module)
|
||||
: Executable(std::move(hlo_module)) {}
|
||||
: Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
|
||||
/*hlo_profile_index_map=*/nullptr) {}
|
||||
|
||||
InterpreterExecutable::~InterpreterExecutable() {}
|
||||
|
||||
@ -156,10 +157,5 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteAsyncOnStream(
|
||||
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> InterpreterExecutable::CreateCostAnalysis()
|
||||
const {
|
||||
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
|
||||
}
|
||||
|
||||
} // namespace interpreter
|
||||
} // namespace xla
|
||||
|
@ -61,8 +61,6 @@ class InterpreterExecutable : public Executable {
|
||||
|
||||
static int64 ShapeSizeBytes(const Shape& shape);
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
|
||||
};
|
||||
|
@ -575,12 +575,13 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
// profile.
|
||||
for (auto& index_to_profiled_stream : index_to_profiled_streams) {
|
||||
int64 device = index_to_profiled_stream.first;
|
||||
auto& module = executables[device]->module();
|
||||
se::Stream* stream = index_to_profiled_stream.second;
|
||||
HloExecutionProfile hlo_profile(module,
|
||||
*executables[device]->CreateCostAnalysis());
|
||||
TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile(
|
||||
&hlo_profile, stream->parent()));
|
||||
Executable* executable = executables[device];
|
||||
const HloModule& module = executable->module();
|
||||
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(),
|
||||
&executable->hlo_profile_index_map());
|
||||
TF_RETURN_IF_ERROR(
|
||||
executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
|
||||
XLA_LOG_LINES(
|
||||
tensorflow::INFO,
|
||||
hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription()));
|
||||
|
@ -773,6 +773,11 @@ xla_test(
|
||||
xla_test(
|
||||
name = "bfloat16_test",
|
||||
srcs = ["bfloat16_test.cc"],
|
||||
blacklisted_backends = [
|
||||
"cpu",
|
||||
"cpu_parallel",
|
||||
"gpu",
|
||||
],
|
||||
shard_count = 40,
|
||||
deps = [
|
||||
":test_utils",
|
||||
@ -1343,6 +1348,7 @@ xla_test(
|
||||
srcs = ["client_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
|
@ -51,8 +51,7 @@ class Bfloat16Test : public ClientLibraryTestBase {
|
||||
const ErrorSpec error_spec_{0.001, 0.001};
|
||||
};
|
||||
|
||||
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
|
||||
DISABLED_ON_CPU(ScalarOperation)))) {
|
||||
XLA_TEST_F(Bfloat16Test, ScalarOperation) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
@ -62,8 +61,7 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
error_spec_);
}

XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(NegateScalarF16)))) {
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));

@ -71,5 +69,83 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
error_spec_);
}

XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
const int kFeatureIndex = 2;
ComputationBuilder builder(client_, TestName());

auto operand = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});

auto scale = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)});

auto offset = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});

auto tuple = builder.BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);

auto expected = *Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.7f)}, {static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.65f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});

ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}

XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
const int kFeatureIndex = 2;
ComputationBuilder builder(client_, TestName());

auto operand = builder.ConstantR4FromArray4D<bfloat16>(
Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));

auto scale = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});

auto mean = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)});

auto var = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});

auto grad_output = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});

builder.BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);

auto expected = *Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
Literal::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});

ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}

} // namespace
} // namespace xla
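For reference, the expected tuple in the BatchNormTraining test above follows directly from normalizing each feature along dimension kFeatureIndex = 2. A minimal NumPy sketch (illustrative only; the variable names here are ours, not part of the change) reproduces the expected output, mean, and variance up to bfloat16 rounding:

# Hypothetical check of the BatchNormTraining expectations above (NumPy only).
import numpy as np

operand = np.array([[[[1.], [2.]], [[3.], [4.]]],
                    [[[5.], [6.]], [[7.], [8.]]]], dtype=np.float32)
scale = np.array([2., 3.], dtype=np.float32)
offset = np.array([1., 2.], dtype=np.float32)
epsilon = 0.001

# Feature axis is 2, so reduce over all other axes.
mean = operand.mean(axis=(0, 1, 3))   # -> [4., 5.], the expected mean literal
var = operand.var(axis=(0, 1, 3))     # -> [5., 5.], the expected variance literal
normalized = (operand - mean.reshape(1, 1, 2, 1)) / np.sqrt(
    var.reshape(1, 1, 2, 1) + epsilon)
output = normalized * scale.reshape(1, 1, 2, 1) + offset.reshape(1, 1, 2, 1)
print(np.round(output, 2))  # matches the expected R4 literal up to bfloat16 rounding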
@ -29,6 +29,7 @@ def xla_test(name,
deps,
xla_test_library_deps=[],
backends=[],
blacklisted_backends=[],
args=[],
tags=[],
copts=[],
@ -92,17 +93,24 @@ def xla_test(name,
backends: A list of backends to generate tests for. Supported
values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will
be generated for all supported backends.
blacklisted_backends: A list of backends to NOT generate tests for.
args: Test arguments for the target.
tags: Tags for the target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
copts: Additional copts to pass to the build.
data: Additional data to pass to the build.
backend_tags: A dict mapping backend name to list of additional tags to
use for that target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
**kwargs: Additional keyword arguments to pass to native.cc_test.
"""
test_names = []
if not backends:
backends = all_backends

backends = [backend for backend in backends
if backend not in blacklisted_backends]

native.cc_library(
name="%s_lib" % name,
srcs=srcs,
@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@ -42,26 +44,26 @@ TEST_F(ClientTest, ExecuteWithLayout) {
for (const std::vector<int64>& transfer_layout : layouts) {
b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
b.ConstantR2<int32>({{10, 20}, {30, 40}}));
auto computation = b.Build();
ASSERT_TRUE(computation.ok()) << computation.status();
TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());

ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
execute_layout);
std::unique_ptr<GlobalData> data =
client_->Execute(computation.ValueOrDie(), {}, &execution_options)
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));

std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));

auto computed = client_->Transfer(*data, &expected_literal->shape());
TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));

LiteralTestUtil::AssertEqualShapesAndLayouts(
expected_literal->shape(), computed.ValueOrDie()->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
computed->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
}
}
}
@ -72,8 +74,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
b.ConstantR2<int32>({{10, 20}, {30, 40}})});

auto computation = b.Build();
ASSERT_TRUE(computation.ok()) << computation.status();
TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());

ExecutionOptions execution_options = execution_options_;
// Create a result shape with one element column major and the other row
@ -85,10 +86,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})});

auto result =
client_
->ExecuteAndTransfer(computation.ValueOrDie(), {}, &execution_options)
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
result->tuple_literals(0));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
@ -107,5 +107,42 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0})));
}

TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}})));

ComputationBuilder b(client_, TestName() + ".add");
b.Add(b.Parameter(0, shape, "param_0"),
b.ConstantR2<int32>({{1, 2}, {3, 4}}));
TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build());

// We can't really test parallel execution on CPU since all of the cores in a
// CPU are presented as a single device. So for now we test "parallel"
// execution on a single device.
std::vector<Client::ComputationInstance> computation_instances;
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
ASSERT_EQ(devices.size(), 1);

ExecutionOptions options = execution_options_;
*options.add_device_handles() = devices[0];
computation_instances.push_back(Client::ComputationInstance(
add_with_one_arg, {const_arg.get()}, options, nullptr));

TF_ASSERT_OK_AND_ASSIGN(auto results,
client_->ExecuteParallel(computation_instances));
auto expected_result = Literal::CreateR2<int32>({{6, 8}, {10, 12}});

TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));

LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
}

} // namespace
} // namespace xla
@ -37,7 +37,7 @@ set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
-std=c++11 -fno-rtti -fno-exceptions \
-O2 -Wno-narrowing -fomit-frame-pointer \
-mfpu=neon -mfloat-abi=softfp -fPIE \
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
-ftemplate-depth=900 \
-DGOOGLE_PROTOBUF_NO_RTTI \
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
@ -16,7 +16,7 @@
#include <string>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/device_base.h"

@ -408,7 +408,7 @@ tf_cc_test(
# Learner/stochastic
cc_library(
name = "gradient-stats",
hdrs = ["learner/stochastic/stats/gradient-stats.h"],
hdrs = ["learner/common/stats/gradient-stats.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
@ -417,7 +417,7 @@ cc_library(

cc_library(
name = "node-stats",
hdrs = ["learner/stochastic/stats/node-stats.h"],
hdrs = ["learner/common/stats/node-stats.h"],
deps = [
":gradient-stats",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
@ -429,7 +429,7 @@ cc_library(

cc_library(
name = "split-stats",
hdrs = ["learner/stochastic/stats/split-stats.h"],
hdrs = ["learner/common/stats/split-stats.h"],
deps = [
":node-stats",
],
@ -437,7 +437,7 @@ cc_library(

cc_library(
name = "feature-split-candidate",
hdrs = ["learner/stochastic/stats/feature-split-candidate.h"],
hdrs = ["learner/common/stats/feature-split-candidate.h"],
deps = [
":split-stats",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
@ -447,7 +447,7 @@ cc_library(
tf_cc_test(
name = "node-stats_test",
size = "small",
srcs = ["learner/stochastic/stats/node-stats_test.cc"],
srcs = ["learner/common/stats/node-stats_test.cc"],
deps = [
":node-stats",
"//tensorflow/core:tensor_testutil",
@ -13,10 +13,10 @@
// limitations under the License.
//
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_

#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"

namespace tensorflow {
@ -58,4 +58,4 @@ struct FeatureSplitCandidate {
} // namespace boosted_trees
} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_

#include <math.h>

@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) {
} // namespace boosted_trees
} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/Eigenvalues"
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -298,4 +298,4 @@ struct NodeStats {
} // namespace boosted_trees
} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_

#include <string>

#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"

namespace tensorflow {
namespace boosted_trees {
@ -81,4 +81,4 @@ struct SplitStats {
} // namespace boosted_trees
} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
@ -32,27 +32,41 @@ from tensorflow.python.platform import test
class CrfTest(test.TestCase):

def testCrfSequenceScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
expected_binary_score = sum(
transition_params[tag_indices[i], tag_indices[i + 1]]
for i in range(sequence_lengths - 1))
expected_sequence_score = expected_unary_score + expected_binary_score
self.assertAllClose(tf_sequence_score, expected_sequence_score)
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[4, 5, -3]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([1], dtype=np.int32)
]
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
with self.test_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
expected_binary_score = sum(
transition_params[tag_indices[i], tag_indices[i + 1]]
for i in range(sequence_lengths - 1))
expected_sequence_score = expected_unary_score + expected_binary_score
self.assertAllClose(tf_sequence_score, expected_sequence_score)

def testCrfUnaryScore(self):
inputs = np.array(
@ -89,38 +103,54 @@ class CrfTest(test.TestCase):
self.assertAllClose(tf_binary_score, expected_binary_score)

def testCrfLogNorm(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
all_sequence_scores = []
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[3, -1, 3]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([2], dtype=np.int32)
]

# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequence_scores.append(
crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params)))
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
with self.test_session() as sess:
all_sequence_scores = []

brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
log_norm = crf.crf_log_norm(
inputs=array_ops.expand_dims(inputs, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
log_norm = array_ops.squeeze(log_norm, [0])
tf_brute_force_log_norm, tf_log_norm = sess.run(
[brute_force_log_norm, log_norm])
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequence_scores.append(
crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params)))

self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
log_norm = crf.crf_log_norm(
inputs=array_ops.expand_dims(inputs, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
log_norm = array_ops.squeeze(log_norm, [0])
tf_brute_force_log_norm, tf_log_norm = sess.run(
[brute_force_log_norm, log_norm])

self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)

def testCrfLogLikelihood(self):
inputs = np.array(
@ -201,50 +231,66 @@ class CrfTest(test.TestCase):
expected_max_sequence[:sequence_lengths])

def testCrfDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
# Test both the length-1 and regular cases.
sequence_lengths_list = [
np.array(3, dtype=np.int32),
np.array(1, dtype=np.int32)
]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32),
np.array([[-1, 2, 1]],
dtype=np.float32),
]
tag_indices_list = [
np.array([1, 2, 1, 0], dtype=np.int32),
np.array([2], dtype=np.int32)
]

with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]

# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)
with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []

tf_all_sequence_scores = sess.run(all_sequence_scores)
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)

expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
tf_all_sequence_scores = sess.run(all_sequence_scores)

actual_max_sequence, actual_max_score = crf.crf_decode(
array_ops.expand_dims(inputs, 0),
constant_op.constant(transition_params),
array_ops.expand_dims(sequence_lengths, 0))
actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
actual_max_score = array_ops.squeeze(actual_max_score, [0])
tf_actual_max_sequence, tf_actual_max_score = sess.run(
[actual_max_sequence, actual_max_score])
expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]

self.assertAllClose(tf_actual_max_score, expected_max_score)
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
expected_max_sequence[:sequence_lengths])
actual_max_sequence, actual_max_score = crf.crf_decode(
array_ops.expand_dims(inputs, 0),
constant_op.constant(transition_params),
array_ops.expand_dims(sequence_lengths, 0))
actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
actual_max_score = array_ops.squeeze(actual_max_score, [0])
tf_actual_max_sequence, tf_actual_max_score = sess.run(
[actual_max_sequence, actual_max_score])

self.assertAllClose(tf_actual_max_score, expected_max_score)
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
expected_max_sequence[:sequence_lengths])


if __name__ == "__main__":
@ -53,7 +53,9 @@ from __future__ import print_function
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
"""
# Compute the scores of the given tag sequence.
unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
binary_scores = crf_binary_score(tag_indices, sequence_lengths,
transition_params)
sequence_scores = unary_scores + binary_scores
return sequence_scores
# If max_seq_len is 1, we skip the score calculation and simply gather the
# unary potentials of the single tag.
def _single_seq_fn():
batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
example_inds = array_ops.reshape(
math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
return array_ops.gather_nd(
array_ops.squeeze(inputs, [1]),
array_ops.concat([example_inds, tag_indices], axis=1))

def _multi_seq_fn():
# Compute the scores of the given tag sequence.
unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
binary_scores = crf_binary_score(tag_indices, sequence_lengths,
transition_params)
sequence_scores = unary_scores + binary_scores
return sequence_scores

return utils.smart_cond(
pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)


def crf_log_norm(inputs, sequence_lengths, transition_params):
@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
# algorithm.
first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
first_input = array_ops.squeeze(first_input, [1])
rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

# Compute the alpha values in the forward algorithm in order to get the
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,
sequence_length=sequence_lengths - 1,
initial_state=first_input,
dtype=dtypes.float32)
log_norm = math_ops.reduce_logsumexp(alphas, [1])
return log_norm
# If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
# the "initial state" (the unary potentials).
def _single_seq_fn():
return math_ops.reduce_logsumexp(first_input, [1])

def _multi_seq_fn():
"""Forward computation of alpha values."""
rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

# Compute the alpha values in the forward algorithm in order to get the
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,
sequence_length=sequence_lengths - 1,
initial_state=first_input,
dtype=dtypes.float32)
log_norm = math_ops.reduce_logsumexp(alphas, [1])
return log_norm

max_seq_len = array_ops.shape(inputs)[1]
return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
true_fn=_single_seq_fn,
false_fn=_multi_seq_fn)


def crf_log_likelihood(inputs,
@ -437,45 +469,64 @@ def crf_decode(potentials, transition_params, sequence_length):
sequence_length: A [batch_size] vector of true sequence lengths.

Returns:
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
Contains the highest scoring tag indices.
best_score: A [batch_size] tensor, containing the score of decode_tags.
best_score: A [batch_size] vector, containing the score of `decode_tags`.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value
# If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
# and the max activation.
def _single_seq_fn():
squeezed_potentials = array_ops.squeeze(potentials, [1])
decode_tags = array_ops.expand_dims(
math_ops.argmax(squeezed_potentials, axis=1), 1)
best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score

# Computes forward decoding. Get last score and backpointers.
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn(
crf_fwd_cell,
inputs=inputs,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, O], [B, O]
backpointers = gen_array_ops.reverse_sequence(
backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O]
def _multi_seq_fn():
"""Decoding of highest scoring sequence."""

# Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
dtype=dtypes.int32) # [B]
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn(
crf_bwd_cell,
inputs=backpointers,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, 1]
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T]
decode_tags = gen_array_ops.reverse_sequence(
decode_tags, sequence_length, seq_dim=1) # [B, T]
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value

best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score
# Computes forward decoding. Get last score and backpointers.
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O]
crf_fwd_cell,
inputs=inputs,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32)
backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O]
backpointers, sequence_length - 1, seq_dim=1)

# Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B]
dtype=dtypes.int32)
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1]
crf_bwd_cell,
inputs=backpointers,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32)
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T]
axis=1)
decode_tags = gen_array_ops.reverse_sequence( # [B, T]
decode_tags, sequence_length, seq_dim=1)

best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score

return utils.smart_cond(
pred=math_ops.equal(
potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
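For reference, the unnormalized sequence score that `_multi_seq_fn` computes is simply the sum of the unary potentials along the tag path plus the transition potentials between consecutive tags. A short NumPy sketch (illustrative names only, reusing the numbers from crf_test.py above) makes the arithmetic explicit:

# Hand computation of the CRF sequence score used in the tests above.
import numpy as np

inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1]], dtype=np.float32)
tags = np.array([1, 2, 1], dtype=np.int32)
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
                             dtype=np.float32)

unary = sum(inputs[t][tags[t]] for t in range(len(tags)))        # 5 + 3 + 2 = 10
binary = sum(transition_params[tags[t], tags[t + 1]]
             for t in range(len(tags) - 1))                      # 1 + 2 = 3
print(unary + binary)  # 13.0, the score crf_sequence_score should return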
@ -187,6 +187,7 @@ py_test(
"manual", # b/67958761
],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",

@ -723,5 +723,41 @@ class BatchDatasetSerializationTest(
num_outputs)


class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):

def testPaddedBatch(self):

def build_dataset(seq_lens):
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
lambda x: array_ops.fill([x], x)).padded_batch(
4, padded_shapes=[-1])

seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
self.run_core_tests(lambda: build_dataset(seq_lens1),
lambda: build_dataset(seq_lens2), 8)

def testPaddedBatchNonDefaultPadding(self):

def build_dataset(seq_lens):

def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))

padded_shape = [-1]
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
fill_tuple).padded_batch(
4,
padded_shapes=(padded_shape, padded_shape),
padding_values=(-1, "<end>"))

seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
self.run_core_tests(lambda: build_dataset(seq_lens1),
lambda: build_dataset(seq_lens2), 8)


if __name__ == "__main__":
test.main()
@ -22,8 +22,10 @@ import math
import threading
import time

import numpy as np
from six.moves import zip_longest

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.framework import dtypes
@ -209,6 +211,46 @@ class InterleaveDatasetTest(test.TestCase):
sess.run(get_next)


class InterleaveDatasetSeriazationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):

def _build_iterator_graph(self, input_values, cycle_length, block_length):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
cycle_length, block_length)

def testSerializationCore(self):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
# cycle_length > 1, block_length > 1
cycle_length = 2
block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
lambda: self._build_iterator_graph(
input_values, cycle_length * 2, block_length * 1),
num_outputs)
# cycle_length = 1
cycle_length = 1
block_length = 3
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
None, num_outputs)
# block_length = 1
cycle_length = 2
block_length = 1
self.run_core_tests(
lambda: self._build_iterator_graph(
input_values, cycle_length, block_length),
None, num_outputs)
# pylint: enable=g-long-lambda


class ParallelInterleaveDatasetTest(test.TestCase):

def setUp(self):
@ -41,6 +41,7 @@ def try_import(name): # pylint: disable=invalid-name
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module


stats = try_import("scipy.stats")


@ -62,9 +63,9 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected, scale_shape.eval())
loc = array_ops.zeros(loc_shape)
scale = array_ops.ones(scale_shape)
self.assertAllEqual(
expected,
array_ops.shape(cauchy_lib.Cauchy(loc, scale).sample()).eval())
self.assertAllEqual(expected,
array_ops.shape(
cauchy_lib.Cauchy(loc, scale).sample()).eval())

def _testParamStaticShapes(self, sample_shape, expected):
param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape)
@ -92,8 +93,7 @@ class CauchyTest(test.TestCase):
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)

log_pdf = cauchy.log_prob(x)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.eval().shape)
self.assertAllEqual(cauchy.batch_shape, log_pdf.shape)
@ -115,16 +115,15 @@ class CauchyTest(test.TestCase):
with self.test_session():
batch_size = 6
loc = constant_op.constant([[3.0, -3.0]] * batch_size)
scale = constant_op.constant([[np.sqrt(10.0), np.sqrt(15.0)]] *
batch_size)
scale = constant_op.constant(
[[np.sqrt(10.0), np.sqrt(15.0)]] * batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)

log_pdf = cauchy.log_prob(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.shape, (6, 2))
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
log_pdf.eval().shape)
self.assertAllEqual(cauchy.batch_shape, log_pdf.shape)
@ -248,8 +247,7 @@ class CauchyTest(test.TestCase):
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)

entropy = cauchy.entropy()
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
entropy.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.shape)
self.assertAllEqual(cauchy.batch_shape_tensor().eval(),
entropy.eval().shape)
self.assertAllEqual(cauchy.batch_shape, entropy.shape)
@ -257,7 +255,7 @@ class CauchyTest(test.TestCase):

if not stats:
return
expected_entropy = stats.cauchy(loc, scale).entropy()
expected_entropy = stats.cauchy(loc, scale[0]).entropy().reshape((1, 3))
self.assertAllClose(expected_entropy, entropy.eval())

def testCauchyMode(self):
@ -368,8 +366,8 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)

expected_shape = (tensor_shape.TensorShape(
[n.eval()]).concatenate(cauchy.batch_shape))
expected_shape = (
tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape))

self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)
@ -385,18 +383,18 @@ class CauchyTest(test.TestCase):
samples = cauchy.sample(n)
sample_values = samples.eval()
self.assertEqual(samples.shape, (100000, batch_size, 2))
self.assertAllClose(np.median(sample_values[:, 0, 0]),
loc_v[0], atol=1e-1)
self.assertAllClose(np.median(sample_values[:, 0, 1]),
loc_v[1], atol=1e-1)
self.assertAllClose(
np.median(sample_values[:, 0, 0]), loc_v[0], atol=1e-1)
self.assertAllClose(
np.median(sample_values[:, 0, 1]), loc_v[1], atol=1e-1)

expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval()))
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)

expected_shape = (tensor_shape.TensorShape(
[n.eval()]).concatenate(cauchy.batch_shape))
expected_shape = (
tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape))
self.assertAllEqual(expected_shape, samples.shape)
self.assertAllEqual(expected_shape, sample_values.shape)

@ -428,9 +426,12 @@ class CauchyTest(test.TestCase):
self.assertEqual(cauchy.event_shape, ())
self.assertAllEqual(cauchy.event_shape_tensor().eval(), [])
self.assertAllEqual(
sess.run(cauchy.batch_shape_tensor(),
feed_dict={loc: 5.0,
scale: [1.0, 2.0]}), [2])
sess.run(
cauchy.batch_shape_tensor(),
feed_dict={
loc: 5.0,
scale: [1.0, 2.0]
}), [2])


if __name__ == "__main__":
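The revised entropy expectation above reflects that the differential entropy of a Cauchy distribution depends only on its scale, H = log(4*pi*scale), so the expected array can be built from scale[0] alone and reshaped to the (1, 3) batch shape. A short sketch, assuming SciPy is available and using illustrative values rather than the test's actual parameters, checks the closed form against scipy.stats:

# Sketch: Cauchy differential entropy is log(4 * pi * scale), independent of loc.
import numpy as np
from scipy import stats

scale = np.array([[0.5, 1.0, 2.0]], dtype=np.float64)  # shape (1, 3), like the test
closed_form = np.log(4.0 * np.pi * scale)
from_scipy = stats.cauchy(0.0, scale[0]).entropy().reshape((1, 3))
print(np.allclose(closed_form, from_scipy))  # True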
@ -30,7 +30,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution


__all__ = [
"Cauchy",
]
@ -97,7 +96,7 @@ class Cauchy(distribution.Distribution):
validate_args=False,
allow_nan_stats=True,
name="Cauchy"):
"""Construct Cauchy distributions with loc and and scale `loc` and `scale`.
"""Construct Cauchy distributions.

The parameters `loc` and `scale` must be shaped in a way that supports
broadcasting (e.g. `loc + scale` is a valid operation).
@ -121,8 +120,8 @@ class Cauchy(distribution.Distribution):
"""
parameters = locals()
with ops.name_scope(name, values=[loc, scale]):
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
with ops.control_dependencies([check_ops.assert_positive(scale)]
if validate_args else []):
self._loc = array_ops.identity(loc, name="loc")
self._scale = array_ops.identity(scale, name="scale")
check_ops.assert_same_float_dtype([self._loc, self._scale])
@ -138,8 +137,8 @@ class Cauchy(distribution.Distribution):
@staticmethod
def _param_shapes(sample_shape):
return dict(
zip(("loc", "scale"), ([ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32)] * 2)))
zip(("loc", "scale"),
([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))

@property
def loc(self):
@ -153,13 +152,10 @@ class Cauchy(distribution.Distribution):

def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
array_ops.shape(self.loc),
array_ops.shape(self.scale))
array_ops.shape(self.loc), array_ops.shape(self.scale))

def _batch_shape(self):
return array_ops.broadcast_static_shape(
self.loc.shape,
self.scale.shape)
return array_ops.broadcast_static_shape(self.loc.shape, self.scale.shape)

def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32)
@ -116,6 +116,7 @@ py_library(
deps = [
":clip_weights",
":conditioning_utils",
":tensor_pool",
":virtual_batchnorm",
"//tensorflow/python:util",
],
@ -219,6 +220,37 @@ py_test(
],
)

py_library(
name = "tensor_pool",
srcs = [
"python/features/python/tensor_pool.py",
"python/features/python/tensor_pool_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:util",
],
)

py_test(
name = "tensor_pool_test",
srcs = ["python/features/python/tensor_pool_test.py"],
srcs_version = "PY2AND3",
deps = [
":tensor_pool",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//third_party/py/numpy",
],
)

py_library(
name = "virtual_batchnorm",
srcs = [
35
tensorflow/contrib/gan/python/features/python/tensor_pool.py
Normal file
@ -0,0 +1,35 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor pool stores values from an input tensor and returns a stored one.

See the following papers for more details.
1) `Learning from simulated and unsupervised images through adversarial
training` (https://arxiv.org/abs/1612.07828).
2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks` (https://arxiv.org/abs/1703.10593).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.features.python import tensor_pool_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.tensor_pool_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

__all__ = tensor_pool_impl.__all__
remove_undocumented(__name__, __all__)
@ -0,0 +1,118 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor pool stores values from an input tensor and returns a stored one.

We use this to keep a history of values created by a generator, such that
a discriminator can randomly be trained on some older samples, not just the
current one. This can help to not let the discriminator get too far ahead of the
generator and also to keep the system from oscilating, if the discriminator
forgets too fast what past samples from the generator looked like.

See the following papers for more details.
1) `Learning from simulated and unsupervised images through adversarial
training` (https://arxiv.org/abs/1612.07828).
2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks` (https://arxiv.org/abs/1703.10593).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import random_ops

__all__ = [
'tensor_pool',
]


def tensor_pool(input_value,
pool_size,
pooling_probability=0.5,
name='tensor_pool'):
"""Queue storing input values and returning random previously stored ones.

Every time the returned `output_value` is evaluated, `input_value` is
evaluated and its value either directly returned (with
`1-pooling_probability`) or stored in the pool and a random one of the samples
currently in the pool is popped and returned. As long as the pool in not fully
filled, the input_value is always directly returned, as well as stored in the
pool. Note during inference / testing, it may be appropriate to set
`pool_size` = 0 or `pooling_probability` = 0.

Args:
input_value: A `Tensor` from which to read values to be pooled.
pool_size: An integer specifying the maximum size of the pool.
pooling_probability: A float `Tensor` specifying the probability of getting
a value from the pool, as opposed to just the current input.
name: A string prefix for the name scope for all tensorflow ops.

Returns:
A `Tensor` which is with given probability either the `input_value` or a
randomly chosen sample that was previously inserted in the pool.

Raises:
ValueError: If `pool_size` is negative.
"""
pool_size = int(pool_size)
if pool_size < 0:
raise ValueError('`pool_size` is negative.')
elif pool_size == 0:
return input_value

with ops.name_scope('{}_pool_queue'.format(name),
values=[input_value, pooling_probability]):
pool_queue = data_flow_ops.RandomShuffleQueue(
capacity=pool_size,
min_after_dequeue=0,
dtypes=[input_value.dtype],
shapes=None)

# In pseudeo code this code does the following:
# if not pool_full:
# enqueue(input_value)
# return input_value
# else
# dequeue_value = dequeue_random_sample()
# enqueue(input_value)
# if rand() < pooling_probability:
# return dequeue_value
# else
# return input_value

def _get_input_value_pooled():
enqueue_op = pool_queue.enqueue(input_value)
with ops.control_dependencies([enqueue_op]):
return array_ops.identity(input_value)

def _get_random_pool_value_and_enqueue_input():
dequeue_value = pool_queue.dequeue()
with ops.control_dependencies([dequeue_value]):
enqueue_op = pool_queue.enqueue(input_value)
with ops.control_dependencies([enqueue_op]):
prob = random_ops.random_uniform(
(), dtype=dtypes.float32) < pooling_probability
return control_flow_ops.cond(prob, lambda: dequeue_value,
lambda: input_value)

output_value = control_flow_ops.cond(
pool_queue.size() < pool_size, _get_input_value_pooled,
_get_random_pool_value_and_enqueue_input)

return output_value
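As a usage note, the pool is typically placed between the generator output and the discriminator input so the discriminator occasionally trains on older samples. A minimal, hypothetical sketch (the generator stand-in and the surrounding training graph here are placeholders, not part of this change):

# Hypothetical wiring of tensor_pool between a generator and a discriminator.
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import tensor_pool_impl

generated = tf.random_normal([16, 64])   # stand-in for a generator's output batch
pooled = tensor_pool_impl.tensor_pool(
    generated, pool_size=50, pooling_probability=0.5)
# Feed `pooled` (which is sometimes an older stored batch) to the
# discriminator instead of `generated` when building its loss.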
@ -0,0 +1,94 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.gan.python.features.tensor_pool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.gan.python.features.python import tensor_pool_impl as tensor_pool
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class TensorPoolTest(test.TestCase):

  def test_pool_unknown_input_shape(self):
    """Checks that `input_value` can have unknown shape."""
    input_value = array_ops.placeholder(
        dtype=dtypes.int32, shape=[None, None, 3])
    output_value = tensor_pool.tensor_pool(input_value, pool_size=10)

    with self.test_session(use_gpu=True) as session:
      for i in range(10):
        session.run(output_value, {input_value: [[[i] * 3]]})
        session.run(output_value, {input_value: [[[i] * 3] * 2]})
        session.run(output_value, {input_value: [[[i] * 3] * 5] * 2})

  def test_pool_sequence(self):
    """Checks that values are pooled and returned maximally twice."""
    input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    output_value = tensor_pool.tensor_pool(input_value, pool_size=10)

    with self.test_session(use_gpu=True) as session:
      outs = []
      for i in range(50):
        out = session.run(output_value, {input_value: i})
        outs.append(out)
        self.assertLessEqual(out, i)

      _, counts = np.unique(outs, return_counts=True)
      # Check that each value is returned maximally twice.
      self.assertTrue((counts <= 2).all())

  def test_never_pool(self):
    """Checks that setting `pooling_probability` to zero works."""
    input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    output_value = tensor_pool.tensor_pool(
        input_value, pool_size=10, pooling_probability=0.0)

    with self.test_session(use_gpu=True) as session:
      for i in range(50):
        out = session.run(output_value, {input_value: i})
        self.assertEqual(out, i)

  def test_pooling_probability(self):
    """Checks that `pooling_probability` works."""
    input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pool_size = 10
    pooling_probability = 0.2
    output_value = tensor_pool.tensor_pool(
        input_value,
        pool_size=pool_size,
        pooling_probability=pooling_probability)

    with self.test_session(use_gpu=True) as session:
      not_pooled = 0
      total = 1000
      for i in range(total):
        out = session.run(output_value, {input_value: i})
        if out == i:
          not_pooled += 1
      self.assertAllClose(
          (not_pooled - pool_size) / (total - pool_size),
          1 - pooling_probability,
          atol=0.03)


if __name__ == '__main__':
  test.main()
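(Editor's note, not part of the commit: the assertion in test_pooling_probability follows from a simple expectation. While the pool is still filling, the first pool_size inputs are always returned unchanged; after that, each input is returned unchanged with probability 1 - pooling_probability, so roughly

    E[not_pooled] ≈ pool_size + (1 - pooling_probability) * (total - pool_size),

which is why the test compares (not_pooled - pool_size) / (total - pool_size) against 1 - pooling_probability with a small tolerance.)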
@ -40,6 +40,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
               cov_ema_decay,
               damping,
               layer_collection,
               var_list=None,
               momentum=0.,
               momentum_type="regular",
               norm_constraint=None,
@ -66,6 +67,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
        blocks, kronecker factors, and losses associated with the
        graph. The layer_collection cannot be modified after KfacOptimizer's
        initialization.
      var_list: Optional list or tuple of variables to train. Defaults to the
        list of variables collected in the graph under the key
        `GraphKeys.TRAINABLE_VARIABLES`.
      momentum: The momentum value for this optimizer. Only applies when
        momentum_type is 'regular' or 'adam'. (Default: 0)
      momentum_type: The type of momentum to use in this optimizer, one of
@ -96,9 +100,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
        or 'adam'.
    """

    # We may consider determining the set of variables some other way, but for
    # now it's just all the trainable variables.
    variables = tf_variables.trainable_variables()
    variables = var_list
    if variables is None:
      variables = tf_variables.trainable_variables()

    self._fisher_est = est.FisherEstimator(
        variables,
@ -123,7 +127,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
      raise ValueError("Momentum must be unspecified if using a momentum_type "
                       "other than 'regular' or 'adam'.")

    self._momentum = ops.convert_to_tensor(momentum, name="momentum")
    self._momentum = momentum
    self._momentum_type = momentum_type
    self._norm_constraint = norm_constraint

@ -313,14 +317,17 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
        self._batch_size, dtype=fft_precon_grads[0].dtype)

    # compute the entries of the 2x2 matrix
    m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size
            + self.damping * _inner_product_list(precon_grads, precon_grads))
    m_11 = (
        _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
        self.damping * _inner_product_list(precon_grads, precon_grads))

    m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size
            + self.damping * _inner_product_list(prev_updates, precon_grads))
    m_21 = (
        _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
        self.damping * _inner_product_list(prev_updates, precon_grads))

    m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size
            + self.damping * _inner_product_list(prev_updates, prev_updates))
    m_22 = (
        _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
        self.damping * _inner_product_list(prev_updates, prev_updates))

    def non_zero_prevupd_case():
      r"""Computes optimal (alpha, mu) given non-zero previous update.
@ -406,8 +413,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
    grads = list(grad for (grad, _) in grads_and_vars)
    variables = list(var for (_, var) in grads_and_vars)
    # previous updates are the negative velocities (up to scaling by LR)
    prev_updates = list(-self._zeros_slot(var, "velocity", self._name)
                        for var in variables)
    prev_updates = list(
        -self._zeros_slot(var, "velocity", self._name) for var in variables)

    # Compute optimal velocity update parameters according to quadratic model
    alpha, mu, _ = self._compute_qmodel_hyperparams(
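(Editor's note, not part of the commit, and only as far as the surrounding hunk suggests: m_11, m_21 and m_22 above are the entries of the symmetric 2x2 matrix

    M = [[m_11, m_21],
         [m_21, m_22]],

i.e. damped, batch-normalized inner products of the preconditioned gradients with themselves, with the previous update, and of the previous update with itself. `_compute_qmodel_hyperparams` appears to use this M in the quadratic model it solves for the step size alpha and momentum decay mu; the actual solve lies outside this hunk.)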
@ -28,7 +28,6 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops


# Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky"

@ -202,9 +201,18 @@ def posdef_inv_cholesky(tensor, identity, damping):
  return linalg_ops.cholesky_solve(chol, identity)


def posdef_inv_eig(tensor, identity, damping):
  """Computes inverse(tensor + damping * identity) with eigendecomposition."""
  eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
      tensor + damping * identity)
  return math_ops.matmul(
      eigenvectors / eigenvalues, eigenvectors, transpose_b=True)


posdef_inv_funcs = {
    "matrix_inverse": posdef_inv_matrix_inverse,
    "cholesky": posdef_inv_cholesky,
    "eig": posdef_inv_eig,
}


@ -261,8 +269,8 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
  # generated by the first gradients_impl.gradients call.

  us = [array_ops.zeros_like(y) + float("nan") for y in ys]
  dydxs = gradients_impl.gradients(ys, xs, grad_ys=us,
                                   stop_gradients=stop_gradients)
  dydxs = gradients_impl.gradients(
      ys, xs, grad_ys=us, stop_gradients=stop_gradients)

  # Deal with strange types that gradients_impl.gradients returns but can't
  # deal with.
@ -278,3 +286,6 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
  dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)

  return dysdx

# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
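(Editor's note, not part of the commit: posdef_inv_eig relies on the identity (V diag(e) V^T)^-1 = V diag(1/e) V^T for a symmetric matrix with eigenvectors V and eigenvalues e; `eigenvectors / eigenvalues` is the broadcasted V diag(1/e). A NumPy check of that identity, with a made-up 2x2 example:)

# Illustrative only; assumes `tensor` is symmetric so tensor + damping*I has a
# real eigendecomposition, mirroring the self_adjoint_eig call above.
import numpy as np

tensor = np.array([[2.0, 1.0], [1.0, 3.0]])
damping = 0.1
damped = tensor + damping * np.eye(2)
e, v = np.linalg.eigh(damped)
inv_via_eig = (v / e) @ v.T  # V diag(1/e) V^T, same trick as posdef_inv_eig
np.testing.assert_allclose(inv_via_eig, np.linalg.inv(damped))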
@ -309,7 +309,6 @@ def _fused_batch_norm(inputs,
      new_shape = [-1, channels, 1, 1]
      inputs = array_ops.reshape(inputs, new_shape)
      inputs_shape = inputs.get_shape()
    dtype = inputs.dtype.base_dtype
    if data_format == DATA_FORMAT_NHWC:
      params_shape = inputs_shape[-1:]
    else:
@ -1779,7 +1779,8 @@ class BatchNormTest(test.TestCase):
    dtype = dtypes.float32
    height, width = 3, 3
    with self.test_session():
      images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype)
      images = np.random.uniform(size=(5, height, width, 3)).astype(
          dtype.as_numpy_dtype)
      output = _layers.batch_norm(images, fused=fused)
      expected_name = ('BatchNorm/FusedBatchNorm' if fused else
                       'BatchNorm/batchnorm')
@ -2665,18 +2666,18 @@ class BatchNormTest(test.TestCase):
    # Test case for 11673
    with self.test_session() as sess:
      a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
      b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW',
                                zero_debias_moving_mean=True)
      _layers.batch_norm(
          a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
      a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10))
      b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW',
                                zero_debias_moving_mean=True)
      _layers.batch_norm(
          a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True)
      sess.run(variables_lib.global_variables_initializer())

  def testVariablesAreFloat32(self):
    height, width = 3, 3
    with self.test_session():
      images = random_ops.random_uniform((5, height, width, 3),
                                         seed=1, dtype=dtypes.float16)
      images = random_ops.random_uniform(
          (5, height, width, 3), seed=1, dtype=dtypes.float16)
      _layers.batch_norm(images, scale=True)
      beta = variables.get_variables_by_name('beta')[0]
      gamma = variables.get_variables_by_name('gamma')[0]
@ -2691,17 +2692,13 @@ class BatchNormTest(test.TestCase):
    channels = shape[1]
    images = np.arange(np.product(shape), dtype=dtype).reshape(shape)
    beta = init_ops.constant_initializer(
        np.arange(
            2, channels + 2, dtype=np.float32))
        np.arange(2, channels + 2, dtype=np.float32))
    gamma = init_ops.constant_initializer(
        np.arange(
            10, channels + 10, dtype=np.float32) * 2.0)
        np.arange(10, channels + 10, dtype=np.float32) * 2.0)
    mean = init_ops.constant_initializer(
        np.arange(
            3, channels + 3, dtype=np.float32) * 5.0)
        np.arange(3, channels + 3, dtype=np.float32) * 5.0)
    variance = init_ops.constant_initializer(
        np.arange(
            1, channels + 1, dtype=np.float32) * 4.0)
        np.arange(1, channels + 1, dtype=np.float32) * 4.0)
    output = _layers.batch_norm(
        images,
        fused=True,
@ -2726,7 +2723,6 @@ class BatchNormTest(test.TestCase):
    res_16 = self._runFusedBatchNorm(shape, np.float16)
    self.assertAllClose(res_32, res_16, rtol=1e-3)


  def testAdjustmentCreated(self):
    # Tests that the adjustment is appropriately passed to and used by the core
    # BN layer.
@ -28,7 +28,6 @@ import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@ -369,10 +368,11 @@ class DataFeeder(object):
    if x_is_dict:
      num_samples = list(self._x.values())[0].shape[0]
    elif tensor_util.is_tensor(self._x):
      num_samples = self._x.shape[0].value  # shape will be a Dimension, extract an int
      num_samples = self._x.shape[
          0].value  # shape will be a Dimension, extract an int
    else:
      num_samples = self._x.shape[0]


    if self._shuffle:
      self.indices = self.random_state.permutation(num_samples)
    else:
@ -251,8 +251,9 @@ class SdcaModel(object):

      result_dense = 0.0
      for i in range(len(dense_variables)):
        result_dense += math_ops.matmul(
            dense_features[i], array_ops.expand_dims(dense_variables[i], -1))
        result_dense += math_ops.matmul(dense_features[i],
                                        array_ops.expand_dims(
                                            dense_variables[i], -1))

      # Reshaping to allow shape inference at graph construction time.
      return array_ops.reshape(result_dense, [-1]) + result_sparse
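(Editor's note, not part of the commit: the loop above accumulates per-example dense scores. Each matmul of a [batch, dim_i] feature block with its [dim_i, 1] weight column yields a [batch, 1] column, and the final reshape flattens it to [batch]. A NumPy stand-in with assumed shapes:)

# Illustrative only; shapes and values are assumptions, not taken from the commit.
import numpy as np

dense_features = [np.random.randn(4, 3), np.random.randn(4, 2)]  # [batch, dim_i]
dense_variables = [np.random.randn(3), np.random.randn(2)]       # [dim_i]

result_dense = 0.0
for feat, w in zip(dense_features, dense_variables):
  result_dense += feat @ w[:, None]  # expand_dims(w, -1) -> [dim_i, 1] column
scores = result_dense.reshape(-1)    # one score per example, shape [4]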
@ -164,8 +164,8 @@ def toco_convert(input_data,
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.drop_control_dependency = drop_control_dependency
  model = _model_flags_pb2.ModelFlags()
  model.drop_control_dependency = drop_control_dependency
  toco.inference_type = inference_type
  for idx, input_tensor in enumerate(input_tensors):
    if input_tensor.dtype == _dtypes.float32:
@ -40,6 +40,7 @@ from six import StringIO
# TODO(aselle): Disable GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# pylint: disable=g-import-not-at-top
import tensorflow as tf
from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader
@ -383,7 +384,7 @@ def make_zip_of_tests(zip_path,
    report["toco_log"] = ""
    tf.reset_default_graph()

    with tf.device('/cpu:0'):
    with tf.device("/cpu:0"):
      try:
        inputs, outputs = make_graph(param_dict_real)
      except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
@ -194,7 +194,6 @@ struct ParsedModelFlags {
  Arg<string> input_data_type;
  Arg<string> input_data_types;
  Arg<bool> variable_batch = Arg<bool>(false);
  Arg<bool> drop_control_dependency = Arg<bool>(false);
  Arg<toco::IntList> input_shape;
  Arg<toco::StringMapList> rnn_states;
  Arg<toco::StringMapList> model_checks;
@ -224,6 +223,7 @@ struct ParsedTocoFlags {
  // Deprecated flags
  Arg<string> input_type;
  Arg<string> input_types;
  Arg<bool> drop_control_dependency = Arg<bool>(false);
};

}  // namespace toco
@ -35,8 +35,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using tensorflow::DT_BOOL;
|
||||
using tensorflow::DT_FLOAT;
|
||||
using tensorflow::DT_INT32;
|
||||
using tensorflow::DT_INT64;
|
||||
using tensorflow::DT_UINT8;
|
||||
using tensorflow::GraphDef;
|
||||
using tensorflow::TensorProto;
|
||||
|
||||
@ -1500,10 +1503,29 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
}
|
||||
}
|
||||
|
||||
void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
|
||||
void AddPlaceholder(const string& name, ArrayDataType type,
|
||||
GraphDef* tensorflow_graph) {
|
||||
auto* placeholder = tensorflow_graph->add_node();
|
||||
placeholder->set_op("Placeholder");
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
|
||||
switch (type) {
|
||||
case ArrayDataType::kBool:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
|
||||
break;
|
||||
case ArrayDataType::kFloat:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
|
||||
break;
|
||||
case ArrayDataType::kUint8:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
|
||||
break;
|
||||
case ArrayDataType::kInt32:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
break;
|
||||
case ArrayDataType::kInt64:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
|
||||
}
|
||||
placeholder->set_name(name);
|
||||
}
|
||||
|
||||
@ -1531,7 +1553,9 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
|
||||
void ExportTensorFlowGraphDefImplementation(const Model& model,
|
||||
GraphDef* tensorflow_graph) {
|
||||
for (const auto& input_array : model.flags.input_arrays()) {
|
||||
AddPlaceholder(input_array.name(), tensorflow_graph);
|
||||
AddPlaceholder(input_array.name(),
|
||||
model.arrays.at(input_array.name())->data_type,
|
||||
tensorflow_graph);
|
||||
}
|
||||
for (const auto& rnn_state : model.flags.rnn_states()) {
|
||||
AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
|
||||
|
File diff suppressed because it is too large
@ -23,11 +23,19 @@ limitations under the License.
|
||||
|
||||
namespace toco {
|
||||
|
||||
std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def);
|
||||
struct TensorFlowImportFlags {
|
||||
// If true, control dependencies will be dropped immediately
|
||||
// during the import of the TensorFlow GraphDef.
|
||||
bool drop_control_dependency = false;
|
||||
};
|
||||
|
||||
std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
const ModelFlags& model_flags, const string& input_file_contents);
|
||||
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
|
||||
const tensorflow::GraphDef& graph_def);
|
||||
|
||||
std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
|
||||
const string& input_file_contents);
|
||||
|
||||
} // namespace toco
|
||||
|
||||
|
@ -112,13 +112,6 @@ bool ParseModelFlagsFromCommandLineFlags(
|
||||
"exclusive "
|
||||
"with the 'batch' field: at most one of these two fields can be "
|
||||
"set."),
|
||||
Flag(
|
||||
"drop_control_dependency",
|
||||
parsed_flags.drop_control_dependency.bind(),
|
||||
parsed_flags.drop_control_dependency.default_value(),
|
||||
"If true, ignore control dependency requirements in input TensorFlow "
|
||||
"GraphDef. Otherwise an error will be raised upon control dependency "
|
||||
"inputs."),
|
||||
Flag("rnn_states", parsed_flags.rnn_states.bind(),
|
||||
parsed_flags.rnn_states.default_value(), ""),
|
||||
Flag("model_checks", parsed_flags.model_checks.bind(),
|
||||
@ -316,7 +309,6 @@ void ReadModelFlagsFromCommandLineFlags(
|
||||
} while (false)
|
||||
|
||||
READ_MODEL_FLAG(variable_batch);
|
||||
READ_MODEL_FLAG(drop_control_dependency);
|
||||
|
||||
#undef READ_MODEL_FLAG
|
||||
|
||||
|
@ -138,8 +138,4 @@ message ModelFlags {
|
||||
optional int32 count_max = 3 [default = -1];
|
||||
}
|
||||
repeated ModelCheck model_checks = 14;
|
||||
|
||||
// If true, ignore control dependency requirements in input TensorFlow
|
||||
// GraphDef. Otherwise an error will be raised upon control dependency inputs.
|
||||
optional bool drop_control_dependency = 15;
|
||||
}
|
||||
|
@ -103,6 +103,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
|
||||
parsed_flags.allow_custom_ops.default_value(),
|
||||
"If true, allow TOCO to create TF Lite Custom operators for all the"
|
||||
"unsupported Tensorflow ops."),
|
||||
Flag(
|
||||
"drop_control_dependency",
|
||||
parsed_flags.drop_control_dependency.bind(),
|
||||
parsed_flags.drop_control_dependency.default_value(),
|
||||
"If true, ignore control dependency requirements in input TensorFlow "
|
||||
"GraphDef. Otherwise an error will be raised upon control dependency "
|
||||
"inputs."),
|
||||
};
|
||||
bool asked_for_help =
|
||||
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
|
||||
@ -163,6 +170,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
|
||||
READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
|
||||
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
|
||||
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
|
||||
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
|
||||
|
||||
// Deprecated flag handling.
|
||||
if (parsed_toco_flags.input_type.specified()) {
|
||||
|
@ -36,7 +36,7 @@ enum FileFormat {
|
||||
// are not normally encoded in model files and in general may not be thought
|
||||
// of as properties of models, instead describing how models are to be
|
||||
// processed in the context of the present tooling job.
|
||||
// Next Id: 12
|
||||
// Next Id: 13
|
||||
message TocoFlags {
|
||||
// Input file format
|
||||
optional FileFormat input_format = 1;
|
||||
@ -128,4 +128,12 @@ message TocoFlags {
|
||||
// If true, allow TOCO to create TF Lite Custom operators for all the
|
||||
// unsupported Tensorflow ops.
|
||||
optional bool allow_custom_ops = 10;
|
||||
|
||||
// Applies only to the case when the input format is TENSORFLOW_GRAPHDEF.
|
||||
// If true, then control dependencies will be immediately dropped during
|
||||
// import.
|
||||
// If not set, the default behavior is as follows:
|
||||
// - Default to false if the output format is TENSORFLOW_GRAPHDEF.
|
||||
// - Default to true in all other cases.
|
||||
optional bool drop_control_dependency = 12;
|
||||
}
|
||||
|
@ -85,38 +85,57 @@ void MakeGeneralGraphTransformationsSet(
|
||||
transformations->Add(new MakeInitialDequantizeOperator);
|
||||
}
|
||||
|
||||
void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
|
||||
const bool output_supports_only_float =
|
||||
toco_flags.output_format() == TENSORFLOW_GRAPHDEF;
|
||||
bool SupportsQuantization(FileFormat format) {
|
||||
return (format == GRAPHVIZ_DOT || format == TFLITE);
|
||||
;
|
||||
}
|
||||
|
||||
ArrayDataType specified_final_data_type = ArrayDataType::kNone;
|
||||
bool SupportsFusedActivationFunction(FileFormat format) {
|
||||
return (format == GRAPHVIZ_DOT || format == TFLITE);
|
||||
}
|
||||
|
||||
bool SupportsLstmCell(FileFormat format) {
|
||||
return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT);
|
||||
}
|
||||
|
||||
bool SupportsPreallocatedWorkspace(FileFormat format) {
|
||||
return (format == GRAPHVIZ_DOT || format == TFLITE);
|
||||
}
|
||||
|
||||
bool IsRealValued(toco::ArrayDataType type) {
|
||||
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
|
||||
type == toco::ArrayDataType::kUint8);
|
||||
}
|
||||
|
||||
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
|
||||
const FileFormat output_format = toco_flags.output_format();
|
||||
ArrayDataType type;
|
||||
if (toco_flags.has_inference_input_type()) {
|
||||
specified_final_data_type =
|
||||
ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
|
||||
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
|
||||
} else if (toco_flags.has_inference_type()) {
|
||||
specified_final_data_type =
|
||||
ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
|
||||
}
|
||||
ArrayDataType final_data_type = ArrayDataType::kNone;
|
||||
if (output_supports_only_float) {
|
||||
QCHECK(specified_final_data_type == ArrayDataType::kNone ||
|
||||
specified_final_data_type == ArrayDataType::kFloat);
|
||||
final_data_type = ArrayDataType::kFloat;
|
||||
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
|
||||
} else if (!SupportsQuantization(output_format)) {
|
||||
// Data type is implicitly float for non-quantized formats
|
||||
type = ArrayDataType::kFloat;
|
||||
} else {
|
||||
final_data_type = specified_final_data_type;
|
||||
// Nothing to do. Data types stay as-is.
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < model->flags.input_arrays_size(); i++) {
|
||||
auto* array = model->arrays[model->flags.input_arrays(i).name()].get();
|
||||
string const& array_name = model->flags.input_arrays(i).name();
|
||||
auto* array = model->arrays[array_name].get();
|
||||
// Note that the notion of changing data types only applies to real-numbers
|
||||
// arrays (see the documentation for inference_input_type).
|
||||
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
|
||||
// i.e. represent real numbers by means of quantization parameters,
|
||||
// and not plain integer uint8 input arrays.
|
||||
const bool is_real_numbers = array->data_type == ArrayDataType::kFloat ||
|
||||
array->data_type == ArrayDataType::kUint8;
|
||||
if (is_real_numbers) {
|
||||
array->final_data_type = final_data_type;
|
||||
if (!IsRealValued(array->data_type)) {
|
||||
// Ignore non-real data types.
|
||||
continue;
|
||||
}
|
||||
|
||||
array->final_data_type = type;
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,9 +146,16 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
|
||||
const string& input_file_contents) {
|
||||
std::unique_ptr<Model> model;
|
||||
switch (toco_flags.input_format()) {
|
||||
case TENSORFLOW_GRAPHDEF:
|
||||
model = ImportTensorFlowGraphDef(model_flags, input_file_contents);
|
||||
case TENSORFLOW_GRAPHDEF: {
|
||||
TensorFlowImportFlags tf_import_flags;
|
||||
tf_import_flags.drop_control_dependency =
|
||||
toco_flags.has_drop_control_dependency()
|
||||
? toco_flags.drop_control_dependency()
|
||||
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
|
||||
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
|
||||
input_file_contents);
|
||||
break;
|
||||
}
|
||||
case TFLITE:
|
||||
model = toco::tflite::Import(model_flags, input_file_contents);
|
||||
ResolveModelFlags(model_flags, model.get());
|
||||
@ -148,23 +174,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||
const FileFormat output_format = toco_flags.output_format();
|
||||
const IODataType inference_type = toco_flags.inference_type();
|
||||
|
||||
const bool output_is_tflite = output_format == TFLITE;
|
||||
const bool quantize_output =
|
||||
SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;
|
||||
|
||||
const bool output_is_tflite_quantized =
|
||||
output_is_tflite && inference_type == QUANTIZED_UINT8;
|
||||
|
||||
if (output_is_tflite_quantized) {
|
||||
if (quantize_output) {
|
||||
QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
|
||||
<< "Quantized inference is not allowed with float inputs.";
|
||||
}
|
||||
|
||||
SetArrayFinalDataTypes(toco_flags, model);
|
||||
SetFinalDataTypeOnInputs(toco_flags, model);
|
||||
|
||||
GraphTransformationsSet transformations;
|
||||
MakeGeneralGraphTransformationsSet(&transformations);
|
||||
auto* remove_trivial_reshape = new RemoveTrivialReshape;
|
||||
transformations.Add(remove_trivial_reshape);
|
||||
if (output_format == TFLITE) {
|
||||
if (SupportsFusedActivationFunction(output_format)) {
|
||||
transformations.Add(new FuseActivationFunctions);
|
||||
} else {
|
||||
transformations.Add(new UnfuseActivationFunctions);
|
||||
@ -183,25 +207,24 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||
// easy to pass a new toco flag. Once that is resolved on the DarwiNN
|
||||
// tests side, the special-casing of DarwiNN here can go away.
|
||||
// TODO(benoitjacob): so drop it when we can.
|
||||
if ((output_is_tflite_quantized &&
|
||||
toco_flags.reorder_across_fake_quant())) {
|
||||
if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
|
||||
transformations.Add(new DropFakeQuant);
|
||||
}
|
||||
}
|
||||
transformations.Add(new ConvertPureConvToDepthwise);
|
||||
// TFLite export does not yet support fused LSTM cell.
|
||||
if (output_format == TENSORFLOW_GRAPHDEF) {
|
||||
if (SupportsLstmCell(output_format)) {
|
||||
transformations.Add(new IdentifyLstmCell);
|
||||
}
|
||||
transformations.Add(new ResolveConstantConcatenation);
|
||||
RunGraphTransformations(model, "general graph transformations",
|
||||
transformations);
|
||||
if (output_is_tflite_quantized) {
|
||||
if (quantize_output) {
|
||||
RunGraphTransformations(model, "pre-quantization graph transformations",
|
||||
{new HardcodeMinMax, new DropFakeQuant});
|
||||
}
|
||||
|
||||
if (output_is_tflite_quantized) {
|
||||
if (quantize_output) {
|
||||
if (toco_flags.has_default_ranges_min() &&
|
||||
toco_flags.has_default_ranges_max()) {
|
||||
UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
|
||||
@ -232,7 +255,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||
CheckUnsupportedOperations(*model);
|
||||
}
|
||||
|
||||
if (output_is_tflite) {
|
||||
if (SupportsPreallocatedWorkspace(output_format)) {
|
||||
AllocateTransientArrays(model, kDefaultTransientDataAlignment);
|
||||
LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
|
||||
}
|
||||
|
@ -294,6 +294,7 @@ void LogArray(int log_level, const Model& model, const string& name) {
|
||||
VLOG(log_level) << "Array: " << name;
|
||||
switch (array.data_type) {
|
||||
case ArrayDataType::kNone:
|
||||
VLOG(log_level) << " Data type:";
|
||||
break;
|
||||
case ArrayDataType::kFloat:
|
||||
VLOG(log_level) << " Data type: kFloat";
|
||||
@ -309,6 +310,24 @@ void LogArray(int log_level, const Model& model, const string& name) {
|
||||
<< static_cast<int>(array.data_type) << ")";
|
||||
break;
|
||||
}
|
||||
switch (array.final_data_type) {
|
||||
case ArrayDataType::kNone:
|
||||
VLOG(log_level) << " Final type:";
|
||||
break;
|
||||
case ArrayDataType::kFloat:
|
||||
VLOG(log_level) << " Final type: kFloat";
|
||||
break;
|
||||
case ArrayDataType::kInt32:
|
||||
VLOG(log_level) << " Final type: kInt32";
|
||||
break;
|
||||
case ArrayDataType::kUint8:
|
||||
VLOG(log_level) << " Final type: kUint8";
|
||||
break;
|
||||
default:
|
||||
VLOG(log_level) << " Final type: other (numerical value: "
|
||||
<< static_cast<int>(array.data_type) << ")";
|
||||
break;
|
||||
}
|
||||
if (array.buffer) {
|
||||
VLOG(log_level) << " Constant Buffer";
|
||||
}
|
||||
@ -1016,7 +1035,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
|
||||
}
|
||||
|
||||
RESOLVE_MODEL_FLAG(variable_batch)
|
||||
RESOLVE_MODEL_FLAG(drop_control_dependency)
|
||||
|
||||
#undef RESOLVE_MODEL_FLAG
|
||||
|
||||
@ -1044,12 +1062,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
|
||||
"--output_arrays flag must be given on the command-line.";
|
||||
|
||||
for (const auto& input_array_proto : model->flags.input_arrays()) {
|
||||
QCHECK(!input_array_proto.shape().empty())
|
||||
<< "This model does not have shape defined for input array "
|
||||
<< input_array_proto.name()
|
||||
<< ", so one must be specified by a non-empty --input_shape "
|
||||
"command-line flag.";
|
||||
|
||||
auto& input_array = model->GetOrCreateArray(input_array_proto.name());
|
||||
if (input_array_proto.has_data_type()) {
|
||||
const ArrayDataType specified_type =
|
||||
@ -1072,6 +1084,14 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
|
||||
input_array.data_type = ArrayDataType::kFloat;
|
||||
}
|
||||
|
||||
if (!input_array.has_shape()) {
|
||||
QCHECK(!input_array_proto.shape().empty())
|
||||
<< "This model does not have shape defined for input array "
|
||||
<< input_array_proto.name()
|
||||
<< ", so one must be specified by a non-empty --input_shape "
|
||||
"command-line flag.";
|
||||
}
|
||||
|
||||
// Compare/merge the model->flags describing the input_shape with
|
||||
// the actual input array's shape.
|
||||
auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
|
||||
@ -1563,7 +1583,11 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
|
||||
for (const auto& array_entry : model.arrays) {
|
||||
const auto& array = *array_entry.second;
|
||||
if (array.final_data_type != ArrayDataType::kNone) {
|
||||
CHECK(array.final_data_type == array.data_type);
|
||||
CHECK(array.final_data_type == array.data_type)
|
||||
<< "Array \"" << array_entry.first
|
||||
<< "\" has mis-matching actual and final data types ("
|
||||
<< static_cast<int>(array.data_type) << ","
|
||||
<< static_cast<int>(array.final_data_type) << ").";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -26,7 +26,6 @@ from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
|
||||
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.powersign import *
|
||||
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
|
||||
# pylint: enable=wildcard-import
|
||||
@ -35,12 +34,18 @@ from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
|
||||
_allowed_symbols = [
|
||||
'PowerSignOptimizer', 'AddSignOptimizer'
|
||||
'PowerSignOptimizer',
|
||||
'AddSignOptimizer'
|
||||
'DelayCompensatedGradientDescentOptimizer',
|
||||
'DropStaleGradientOptimizer', 'ExternalOptimizerInterface',
|
||||
'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer',
|
||||
'ScipyOptimizerInterface', 'VariableClippingOptimizer',
|
||||
'MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm',
|
||||
'DropStaleGradientOptimizer',
|
||||
'ExternalOptimizerInterface',
|
||||
'LazyAdamOptimizer',
|
||||
'NadamOptimizer',
|
||||
'MovingAverageOptimizer',
|
||||
'ScipyOptimizerInterface',
|
||||
'VariableClippingOptimizer',
|
||||
'MultitaskOptimizerWrapper',
|
||||
'clip_gradients_by_global_norm',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@ -12,9 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""An optimizer wrapper that ensures correct behaviour
|
||||
of stateful optimizers with multitask loss."""
|
||||
"""An optimizer wrapper for stateful optimizers with multitask loss."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -30,26 +28,27 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.training import optimizer
|
||||
|
||||
__all__ = ["MultitaskOptimizerWrapper",
|
||||
"clip_gradients_by_global_norm"]
|
||||
__all__ = ['MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm']
|
||||
|
||||
|
||||
def _is_all_zeros(grad):
|
||||
all_zeros = math_ops.equal(math_ops.count_nonzero(grad), 0)
|
||||
return all_zeros
|
||||
|
||||
|
||||
def _get_wrapper(fn, opt):
|
||||
|
||||
def wrapper(self, grad, *args, **kwargs): # pylint: disable=unused-argument
|
||||
all_zeros = _is_all_zeros(grad)
|
||||
return control_flow_ops.cond(
|
||||
all_zeros,
|
||||
control_flow_ops.no_op,
|
||||
lambda: fn(grad, *args, **kwargs))
|
||||
return control_flow_ops.cond(all_zeros, control_flow_ops.no_op,
|
||||
lambda: fn(grad, *args, **kwargs))
|
||||
|
||||
wrapper = types.MethodType(wrapper, opt)
|
||||
return wrapper
|
||||
|
||||
|
||||
class MultitaskOptimizerWrapper(object):
|
||||
"""Optimizer wrapper that ensures that
|
||||
all-zero gradients don't affect the optimizer state.
|
||||
"""Optimizer wrapper making all-zero gradients harmless.
|
||||
|
||||
This might be useful when a multi-task loss is used,
|
||||
and some components of the loss might be
|
||||
@ -88,20 +87,20 @@ class MultitaskOptimizerWrapper(object):
|
||||
gradvars_clipped, global_step=batch)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
opt: an instance of a class that implements tf.train.Optimizer.
|
||||
opt: an instance of a class that implements tf.train.Optimizer.
|
||||
"""
|
||||
if not isinstance(opt, optimizer.Optimizer):
|
||||
raise TypeError(
|
||||
"Supplied optimizer must be an instance of tf.train.Optimizer")
|
||||
'Supplied optimizer must be an instance of tf.train.Optimizer')
|
||||
self._opt = opt
|
||||
overriden_methods = ('_apply_dense',
|
||||
'_resource_apply_dense',
|
||||
'_apply_sparse',
|
||||
'_resource_apply_sparse')
|
||||
for name in overriden_methods:
|
||||
overridden_methods = ('_apply_dense', '_resource_apply_dense',
|
||||
'_apply_sparse', '_resource_apply_sparse')
|
||||
for name in overridden_methods:
|
||||
fn = getattr(self._opt, name)
|
||||
wrapper = _get_wrapper(fn, self._opt)
|
||||
setattr(self._opt, name, wrapper)
|
||||
@ -112,27 +111,30 @@ class MultitaskOptimizerWrapper(object):
|
||||
|
||||
def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
|
||||
"""Clips gradients of a multitask loss by their global norm.
|
||||
|
||||
Ignores all-zero tensors when computing the global norm.
|
||||
|
||||
Args:
|
||||
gradients_variables: a list of pairs (gradient, variable).
|
||||
clip_norm: a float Tensor, the global norm to clip on. Default is 20.0.
|
||||
gradients_variables: a list of pairs (gradient, variable).
|
||||
clip_norm: a float Tensor, the global norm to clip on. Default is 20.0.
|
||||
|
||||
Returns:
|
||||
list: A list of pairs of the same type as gradients_variables,.
|
||||
fixed_global_norm: A 0-D (scalar) Tensor representing the global norm.
|
||||
list: A list of pairs of the same type as gradients_variables,.
|
||||
fixed_global_norm: A 0-D (scalar) Tensor representing the global norm.
|
||||
"""
|
||||
gradients, variables = six.moves.zip(*gradients_variables)
|
||||
|
||||
def _replace_nonexisting_grad(grad):
|
||||
if grad is None:
|
||||
return grad
|
||||
all_zeros = _is_all_zeros(grad)
|
||||
return control_flow_ops.cond(all_zeros,
|
||||
lambda: array_ops.zeros(
|
||||
[], dtype=dtypes.as_dtype(grad.dtype)),
|
||||
lambda: grad)
|
||||
return control_flow_ops.cond(
|
||||
all_zeros,
|
||||
lambda: array_ops.zeros([], dtype=dtypes.as_dtype(grad.dtype)),
|
||||
lambda: grad)
|
||||
|
||||
nonzero_gradients = [_replace_nonexisting_grad(g) for g in gradients]
|
||||
fixed_global_norm = clip_ops.global_norm(nonzero_gradients)
|
||||
gradients, _ = clip_ops.clip_by_global_norm(gradients, clip_norm,
|
||||
use_norm=fixed_global_norm)
|
||||
gradients, _ = clip_ops.clip_by_global_norm(
|
||||
gradients, clip_norm, use_norm=fixed_global_norm)
|
||||
return list(six.moves.zip(gradients, variables)), fixed_global_norm
|
||||
|
@ -18,6 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.opt.python.training import multitask_optimizer_wrapper
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -25,13 +28,11 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import momentum
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
"""Tests for the multitask optimizer wrapper.
|
||||
"""
|
||||
Tests for the multitask optimizer wrapper.
|
||||
"""
|
||||
|
||||
def testWrapper(self):
|
||||
with self.test_session():
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
|
||||
@ -39,12 +40,10 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32)
|
||||
grads1 = constant_op.constant([0.01, 0.01], dtype=dtypes.float32)
|
||||
grads_allzero = constant_op.constant([0.0, 0.0], dtype=dtypes.float32)
|
||||
mom_opt_impl = momentum.MomentumOptimizer(
|
||||
learning_rate=2.0, momentum=0.9)
|
||||
mom_opt_impl = momentum.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
|
||||
mom_opt = multitask_optimizer_wrapper.MultitaskOptimizerWrapper(
|
||||
mom_opt_impl)
|
||||
mom_update = mom_opt.apply_gradients(
|
||||
zip([grads0, grads1], [var0, var1]))
|
||||
mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||
mom_update_partial = mom_opt.apply_gradients(
|
||||
zip([grads_allzero, grads1], [var0, var1]))
|
||||
mom_update_no_action = mom_opt.apply_gradients(
|
||||
@ -63,14 +62,13 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
# Step 1: normal momentum update.
|
||||
self.evaluate(mom_update)
|
||||
# Check that the momentum accumulators have been updated.
|
||||
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
|
||||
self.evaluate(slot0))
|
||||
self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
|
||||
self.evaluate(slot1))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([0.1, 0.1]), self.evaluate(slot0))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([0.01, 0.01]), self.evaluate(slot1))
|
||||
# Check that the parameters have been updated.
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
|
||||
self.evaluate(var0))
|
||||
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
|
||||
self.evaluate(var1))
|
||||
@ -78,8 +76,8 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
# Step 2: momentum update that changes only slot1 but not slot0.
|
||||
self.evaluate(mom_update_partial)
|
||||
# Check that only the relevant momentum accumulator has been updated.
|
||||
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
|
||||
self.evaluate(slot0))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([0.1, 0.1]), self.evaluate(slot0))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
|
||||
self.evaluate(slot1))
|
||||
@ -87,8 +85,8 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
# Step 3: momentum update that does not change anything.
|
||||
self.evaluate(mom_update_no_action)
|
||||
# Check that the momentum accumulators have *NOT* been updated.
|
||||
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
|
||||
self.evaluate(slot0))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([0.1, 0.1]), self.evaluate(slot0))
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
|
||||
self.evaluate(slot1))
|
||||
@ -105,8 +103,9 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
grads3 = None
|
||||
varlist = [var0, var1, var2, var3]
|
||||
gradients = [grads0, grads1, grads2, grads3]
|
||||
clipped_gradvars, global_norm = multitask_optimizer_wrapper.clip_gradients_by_global_norm(
|
||||
six.moves.zip(gradients, varlist), clip_norm=1.0)
|
||||
clipped_gradvars, global_norm = (
|
||||
multitask_optimizer_wrapper.clip_gradients_by_global_norm(
|
||||
six.moves.zip(gradients, varlist), clip_norm=1.0))
|
||||
clipped_grads = list(six.moves.zip(*clipped_gradvars))[0]
|
||||
reference_global_norm = np.sqrt(np.sum(np.square([10.0, 15.0, 0.0, 5.0])))
|
||||
self.assertAllCloseAccordingToType(
|
||||
@ -115,5 +114,6 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
|
||||
self.evaluate(clipped_grads[2]), np.array([0., 0.]))
|
||||
self.assertEqual(clipped_grads[3], None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib import rnn as contrib_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -374,19 +375,20 @@ class RNNCellTest(test.TestCase):
|
||||
h = array_ops.zeros([batch_size, num_proj])
|
||||
state = rnn_cell_impl.LSTMStateTuple(c, h)
|
||||
cell = contrib_rnn_cell.LayerNormLSTMCell(
|
||||
num_units=num_units,
|
||||
num_proj=num_proj,
|
||||
forget_bias=1.0,
|
||||
layer_norm=True,
|
||||
norm_gain=1.0,
|
||||
norm_shift=0.0)
|
||||
num_units=num_units,
|
||||
num_proj=num_proj,
|
||||
forget_bias=1.0,
|
||||
layer_norm=True,
|
||||
norm_gain=1.0,
|
||||
norm_shift=0.0)
|
||||
g, out_m = cell(x, state)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([g, out_m], {
|
||||
x.name: np.ones((batch_size, input_size)),
|
||||
c.name: 0.1 * np.ones((batch_size, num_units)),
|
||||
h.name: 0.1 * np.ones((batch_size, num_proj))
|
||||
})
|
||||
res = sess.run(
|
||||
[g, out_m], {
|
||||
x.name: np.ones((batch_size, input_size)),
|
||||
c.name: 0.1 * np.ones((batch_size, num_units)),
|
||||
h.name: 0.1 * np.ones((batch_size, num_proj))
|
||||
})
|
||||
self.assertEqual(len(res), 2)
|
||||
# The numbers in results were not calculated, this is mostly just a
|
||||
# smoke test.
|
||||
@ -396,9 +398,9 @@ class RNNCellTest(test.TestCase):
|
||||
# Different inputs so different outputs and states
|
||||
for i in range(1, batch_size):
|
||||
self.assertTrue(
|
||||
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6)
|
||||
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6)
|
||||
self.assertTrue(
|
||||
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
|
||||
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
|
||||
|
||||
def testOutputProjectionWrapper(self):
|
||||
with self.test_session() as sess:
|
||||
|
@ -996,26 +996,19 @@ class RNNCellTest(test.TestCase):
|
||||
output, state = cell(x, hidden)
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state], {
|
||||
hidden[0].name:
|
||||
np.array([[[[[1.],[1.]],
|
||||
[[1.],[1.]]],
|
||||
[[[1.],[1.]],
|
||||
[[1.],[1.]]]],
|
||||
[[[[2.],[2.]],
|
||||
[[2.],[2.]]],
|
||||
[[[2.],[2.]],
|
||||
[[2.],[2.]]]]]),
|
||||
x.name:
|
||||
np.array([[[[[1.],[1.]],
|
||||
[[1.],[1.]]],
|
||||
[[[1.],[1.]],
|
||||
[[1.],[1.]]]],
|
||||
[[[[2.],[2.]],
|
||||
[[2.],[2.]]],
|
||||
[[[2.],[2.]],
|
||||
[[2.],[2.]]]]])
|
||||
})
|
||||
res = sess.run(
|
||||
[output, state], {
|
||||
hidden[0].name:
|
||||
np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
|
||||
1.
|
||||
], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]],
|
||||
[[[2.], [2.]], [[2.], [2.]]]]]),
|
||||
x.name:
|
||||
np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
|
||||
1.
|
||||
], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], [[[2.], [2.]],
|
||||
[[2.], [2.]]]]])
|
||||
})
|
||||
# This is a smoke test, making sure expected values are unchanged.
|
||||
self.assertEqual(len(res), 2)
|
||||
self.assertAllClose(res[0], res[1].h)
|
||||
@ -1276,10 +1269,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
self.assertAllClose(res[2].c, expected_c1, 1e-5)
|
||||
self.assertAllClose(res[2].h, expected_h1, 1e-5)
|
||||
|
||||
|
||||
def testBasicLSTMCellWithStateTupleLayerNorm(self):
|
||||
"""The results of LSTMCell and LayerNormBasicLSTMCell
|
||||
should be same. """
|
||||
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
@ -1290,21 +1281,21 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
c1 = array_ops.zeros([1, 2])
|
||||
h1 = array_ops.zeros([1, 2])
|
||||
state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
|
||||
cell = rnn_cell_impl.MultiRNNCell(
|
||||
[contrib_rnn_cell.LayerNormLSTMCell(
|
||||
2,
|
||||
layer_norm=True,
|
||||
norm_gain=1.0,
|
||||
norm_shift=0.0) for _ in range(2)])
|
||||
cell = rnn_cell_impl.MultiRNNCell([
|
||||
contrib_rnn_cell.LayerNormLSTMCell(
|
||||
2, layer_norm=True, norm_gain=1.0, norm_shift=0.0)
|
||||
for _ in range(2)
|
||||
])
|
||||
h, (s0, s1) = cell(x, (state0, state1))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([h, s0, s1], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
c0.name: 0.1 * np.asarray([[0, 1]]),
|
||||
h0.name: 0.1 * np.asarray([[2, 3]]),
|
||||
c1.name: 0.1 * np.asarray([[4, 5]]),
|
||||
h1.name: 0.1 * np.asarray([[6, 7]]),
|
||||
})
|
||||
res = sess.run(
|
||||
[h, s0, s1], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
c0.name: 0.1 * np.asarray([[0, 1]]),
|
||||
h0.name: 0.1 * np.asarray([[2, 3]]),
|
||||
c1.name: 0.1 * np.asarray([[4, 5]]),
|
||||
h1.name: 0.1 * np.asarray([[6, 7]]),
|
||||
})
|
||||
|
||||
expected_h = np.array([[-0.38079708, 0.38079708]])
|
||||
expected_h0 = np.array([[-0.38079708, 0.38079708]])
|
||||
|
@ -115,7 +115,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
|
||||
|
||||
The class uses optional peep-hole connections, and an optional projection
|
||||
layer.
|
||||
|
||||
Layer normalization implementation is based on:
|
||||
|
||||
https://arxiv.org/abs/1607.06450.
|
||||
@ -124,15 +123,24 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
|
||||
Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
|
||||
|
||||
and is applied before the internal nonlinearities.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_units, use_peepholes=False,
|
||||
initializer=None, num_proj=None, proj_clip=None,
|
||||
num_unit_shards=1, num_proj_shards=1,
|
||||
forget_bias=1.0, state_is_tuple=True,
|
||||
activation=math_ops.tanh, reuse=None,
|
||||
layer_norm=False, norm_gain=1.0, norm_shift=0.0):
|
||||
def __init__(self,
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
initializer=None,
|
||||
num_proj=None,
|
||||
proj_clip=None,
|
||||
num_unit_shards=1,
|
||||
num_proj_shards=1,
|
||||
forget_bias=1.0,
|
||||
state_is_tuple=True,
|
||||
activation=math_ops.tanh,
|
||||
reuse=None,
|
||||
layer_norm=False,
|
||||
norm_gain=1.0,
|
||||
norm_shift=0.0):
|
||||
"""Initialize the parameters for an LSTM cell.
|
||||
|
||||
Args:
|
||||
@ -164,8 +172,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
|
||||
`layer_norm` has been set to `False`, this argument will be ignored.
|
||||
norm_shift: float, The layer normalization shift initial value. If
|
||||
`layer_norm` has been set to `False`, this argument will be ignored.
|
||||
|
||||
|
||||
"""
|
||||
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
|
||||
if not state_is_tuple:
|
||||
@ -2049,8 +2055,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
|
||||
if self._skip_connection:
|
||||
self._total_output_channels += self._input_shape[-1]
|
||||
|
||||
state_size = tensor_shape.TensorShape(self._input_shape[:-1]
|
||||
+ [self._output_channels])
|
||||
state_size = tensor_shape.TensorShape(
|
||||
self._input_shape[:-1] + [self._output_channels])
|
||||
self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
|
||||
self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
|
||||
+ [self._total_output_channels])
|
||||
@ -2110,11 +2116,8 @@ class Conv3DLSTMCell(ConvLSTMCell):
|
||||
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
|
||||
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
|
||||
|
||||
def _conv(args,
|
||||
filter_size,
|
||||
num_features,
|
||||
bias,
|
||||
bias_start=0.0):
|
||||
|
||||
def _conv(args, filter_size, num_features, bias, bias_start=0.0):
|
||||
"""convolution:
|
||||
Args:
|
||||
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
|
||||
@ -2391,12 +2394,19 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_units,
|
||||
use_peepholes=False, cell_clip=None,
|
||||
initializer=None, num_proj=None, proj_clip=None,
|
||||
def __init__(self,
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
cell_clip=None,
|
||||
initializer=None,
|
||||
num_proj=None,
|
||||
proj_clip=None,
|
||||
forget_bias=1.0,
|
||||
activation=None, layer_norm=False,
|
||||
norm_gain=1.0, norm_shift=0.0, reuse=None):
|
||||
activation=None,
|
||||
layer_norm=False,
|
||||
norm_gain=1.0,
|
||||
norm_shift=0.0,
|
||||
reuse=None):
|
||||
"""Initialize the parameters for an LSTM cell.
|
||||
|
||||
Args:
|
||||
@ -2457,7 +2467,6 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
def output_size(self):
|
||||
return self._output_size
|
||||
|
||||
|
||||
def _linear(self,
|
||||
args,
|
||||
output_size,
|
||||
@ -2507,9 +2516,9 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
scope = vs.get_variable_scope()
|
||||
with vs.variable_scope(scope) as outer_scope:
|
||||
weights = vs.get_variable(
|
||||
"kernel", [total_arg_size, output_size],
|
||||
dtype=dtype,
|
||||
initializer=kernel_initializer)
|
||||
"kernel", [total_arg_size, output_size],
|
||||
dtype=dtype,
|
||||
initializer=kernel_initializer)
|
||||
if len(args) == 1:
|
||||
res = math_ops.matmul(args[0], weights)
|
||||
else:
|
||||
@ -2521,9 +2530,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
if bias_initializer is None:
|
||||
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
|
||||
biases = vs.get_variable(
|
||||
"bias", [output_size],
|
||||
dtype=dtype,
|
||||
initializer=bias_initializer)
|
||||
"bias", [output_size], dtype=dtype, initializer=bias_initializer)
|
||||
|
||||
if not layer_norm:
|
||||
res = nn_ops.bias_add(res, biases)
|
||||
@ -2554,7 +2561,6 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
ValueError: If input size cannot be inferred from inputs via
|
||||
static shape inference.
|
||||
"""
|
||||
num_proj = self._num_units if self._num_proj is None else self._num_proj
|
||||
sigmoid = math_ops.sigmoid
|
||||
|
||||
(c_prev, m_prev) = state
|
||||
@ -2567,10 +2573,14 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
|
||||
|
||||
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
||||
lstm_matrix = self._linear([inputs, m_prev], 4 * self._num_units, bias=True,
|
||||
bias_initializer=None, layer_norm=self._layer_norm)
|
||||
lstm_matrix = self._linear(
|
||||
[inputs, m_prev],
|
||||
4 * self._num_units,
|
||||
bias=True,
|
||||
bias_initializer=None,
|
||||
layer_norm=self._layer_norm)
|
||||
i, j, f, o = array_ops.split(
|
||||
value=lstm_matrix, num_or_size_splits=4, axis=1)
|
||||
value=lstm_matrix, num_or_size_splits=4, axis=1)
|
||||
|
||||
if self._layer_norm:
|
||||
i = _norm(self._norm_gain, self._norm_shift, i, "input")
|
||||
@ -2580,20 +2590,22 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
|
||||
# Diagonal connections
|
||||
if self._use_peepholes:
|
||||
with vs.variable_scope(unit_scope) as projection_scope:
|
||||
with vs.variable_scope(unit_scope):
|
||||
w_f_diag = vs.get_variable(
|
||||
"w_f_diag", shape=[self._num_units], dtype=dtype)
|
||||
"w_f_diag", shape=[self._num_units], dtype=dtype)
|
||||
w_i_diag = vs.get_variable(
|
||||
"w_i_diag", shape=[self._num_units], dtype=dtype)
|
||||
"w_i_diag", shape=[self._num_units], dtype=dtype)
|
||||
w_o_diag = vs.get_variable(
|
||||
"w_o_diag", shape=[self._num_units], dtype=dtype)
|
||||
"w_o_diag", shape=[self._num_units], dtype=dtype)
|
||||
|
||||
if self._use_peepholes:
|
||||
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
|
||||
sigmoid(i + w_i_diag * c_prev) * self._activation(j))
|
||||
c = (
|
||||
sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
|
||||
sigmoid(i + w_i_diag * c_prev) * self._activation(j))
|
||||
else:
|
||||
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
|
||||
self._activation(j))
|
||||
c = (
|
||||
sigmoid(f + self._forget_bias) * c_prev +
|
||||
sigmoid(i) * self._activation(j))
|
||||
|
||||
if self._layer_norm:
|
||||
c = _norm(self._norm_gain, self._norm_shift, c, "state")
|
||||
@ -2608,7 +2620,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
||||
m = sigmoid(o) * self._activation(c)
|
||||
|
||||
if self._num_proj is not None:
|
||||
with vs.variable_scope("projection") as proj_scope:
|
||||
with vs.variable_scope("projection"):
|
||||
m = self._linear(m, self._num_proj, bias=False)
|
||||
|
||||
if self._proj_clip is not None:
|
||||
|
@ -192,7 +192,8 @@ class _BaseAttentionMechanism(AttentionMechanism):
|
||||
raise TypeError("probability_fn must be callable, saw type: %s" %
|
||||
type(probability_fn).__name__)
|
||||
if score_mask_value is None:
|
||||
score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf)
|
||||
score_mask_value = dtypes.as_dtype(
|
||||
self._memory_layer.dtype).as_numpy_dtype(-np.inf)
|
||||
self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda
|
||||
probability_fn(
|
||||
_maybe_mask_score(score, memory_sequence_length, score_mask_value),
|
||||
@ -1145,7 +1146,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
|
||||
% (len(attention_layer_sizes), len(attention_mechanisms)))
|
||||
self._attention_layers = tuple(
|
||||
layers_core.Dense(
|
||||
attention_layer_size, name="attention_layer", use_bias=False,
|
||||
attention_layer_size,
|
||||
name="attention_layer",
|
||||
use_bias=False,
|
||||
dtype=attention_mechanisms[i].dtype)
|
||||
for i, attention_layer_size in enumerate(attention_layer_sizes))
|
||||
self._attention_layer_size = sum(attention_layer_sizes)
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma.h"
|
||||
#include <fcntl.h>
|
||||
#include <cstdlib>
|
||||
#include <fcntl.h>
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
@ -137,7 +138,7 @@ ibv_device* set_device() {
|
||||
if (!env_p_rdma_device.empty()) {
|
||||
for (device_index = 0; device_index < dev_num; device_index++) {
|
||||
if (!env_p_rdma_device.compare(
|
||||
ibv_get_device_name(dev_list[device_index]))) {
|
||||
ibv_get_device_name(dev_list[device_index]))) {
|
||||
CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
|
||||
<< "Device " << ibv_get_device_name(dev_list[device_index])
|
||||
<< " has no active ports";
|
||||
@ -147,7 +148,7 @@ ibv_device* set_device() {
|
||||
// check validity of input device
|
||||
CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
|
||||
} else {
|
||||
// set default device
|
||||
// set default device
|
||||
str_port_num = get_env_var("RDMA_DEVICE_PORT");
|
||||
CHECK(str_port_num.empty())
|
||||
<< "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
|
||||
@ -177,7 +178,7 @@ ibv_device* set_device() {
|
||||
// Returns:
|
||||
// port to use
|
||||
uint8_t set_port(ibv_context* context) {
|
||||
uint8_t port_num = 0; //0 is illegal port number
|
||||
uint8_t port_num = 0; // 0 is illegal port number
|
||||
string str_port_num;
|
||||
ibv_device_attr device_att;
|
||||
ibv_port_attr port_attr;
|
||||
@ -199,9 +200,7 @@ uint8_t set_port(ibv_context* context) {
|
||||
// check if port id active
|
||||
CHECK(port_attr.state == IBV_PORT_ACTIVE)
|
||||
<< "Selected RDMA_DEVICE_PORT is not active";
|
||||
}
|
||||
// set default port
|
||||
else {
|
||||
} else { // set default port
|
||||
for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
|
||||
rc = ibv_query_port(context, port_index, &port_attr);
|
||||
CHECK(!rc) << "Failed to query the port" << port_index;
|
||||
@ -269,7 +268,7 @@ bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
|
||||
// Function to set GID index.
|
||||
// If the port link is IB, no GID index should be selected.
|
||||
// If Ethernet but RDMA_GID_INDEX not set gid index that supports
|
||||
// RoCE V2 will be chosen(fails if more then one IP is configured)
|
||||
// RoCE V2 will be chosen(fails if more than one IP is configured)
|
||||
// Args:
|
||||
// context - device context
|
||||
// port_num - port number
|
||||
@ -302,7 +301,7 @@ uint8_t set_gid(uint8_t port_num, ibv_context* context) {
|
||||
}
|
||||
}
|
||||
switch (port_attr.link_layer) {
|
||||
case(IBV_LINK_LAYER_ETHERNET) :
|
||||
case (IBV_LINK_LAYER_ETHERNET):
|
||||
gid_str = get_env_var("RDMA_GID_INDEX");
|
||||
if (!gid_str.empty()) {
|
||||
gid_index = stoi(gid_str);
|
||||
@ -313,7 +312,7 @@ uint8_t set_gid(uint8_t port_num, ibv_context* context) {
|
||||
<< "More than one IP is available, please specify GID_INDEX";
|
||||
}
|
||||
break;
|
||||
case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index
|
||||
case (IBV_LINK_LAYER_INFINIBAND): // no need in GID index
|
||||
break;
|
||||
default:
|
||||
LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
|
||||
@ -374,7 +373,8 @@ enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
break;
default:
CHECK(0) << "Error: MTU input value must be one of the following: 256, "
"512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n";
"512, 1024, 2048, 4096. MTU "
<< mtu << " is invalid\n";
break;
}
CHECK(mtu < port_attr.active_mtu)
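The switch that produces this error message is truncated in this view. A self-contained sketch of the kind of value-to-enum mapping it performs, using the standard verbs enum constants (illustrative only, not the code elided above):

#include <cstdlib>
#include <infiniband/verbs.h>

// Illustrative only: accept exactly the MTU values named in the error
// message above and map them to the verbs enum; anything else is fatal.
ibv_mtu MtuFromInt(int mtu) {
  switch (mtu) {
    case 256: return IBV_MTU_256;
    case 512: return IBV_MTU_512;
    case 1024: return IBV_MTU_1024;
    case 2048: return IBV_MTU_2048;
    case 4096: return IBV_MTU_4096;
    default:
      std::abort();  // mirrors the CHECK(0) above
  }
}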
@ -921,7 +921,7 @@ Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
|
||||
if (first_value == 0) {
|
||||
*out = MakeDim(second);
|
||||
} else if (second_value == 0) {
|
||||
*out = MakeDim(first);
|
||||
*out = first;
|
||||
} else if (first_value == kUnknownDim || second_value == kUnknownDim) {
|
||||
*out = UnknownDim();
|
||||
} else {
|
||||
@ -946,7 +946,7 @@ Status InferenceContext::Subtract(DimensionHandle first,
|
||||
const int64 second_value = Value(second);
|
||||
// Special cases.
|
||||
if (second_value == 0) {
|
||||
*out = MakeDim(first);
|
||||
*out = first;
|
||||
} else if (first_value == kUnknownDim || second_value == kUnknownDim) {
|
||||
*out = UnknownDim();
|
||||
} else {
|
||||
|
@ -455,7 +455,6 @@ class Graph {
|
||||
// the corresponding NodeDef to reflect the change.
|
||||
// REQUIRES: The control edge must exist.
|
||||
void RemoveControlEdge(const Edge* e);
|
||||
|
||||
// Updates the input to a node. The existing edge to `dst` is removed and an
|
||||
// edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
|
||||
// is also updated.
|
||||
|
@ -118,11 +118,9 @@ class GraphTest : public ::testing::Test {
|
||||
LOG(FATAL) << name;
|
||||
}
|
||||
|
||||
bool ControlEdgeExistsInGraphOrNodeDef(const Node* src,
|
||||
const Node* dst) {
|
||||
for (const Edge *e : dst->in_edges()) {
|
||||
if (e->IsControlEdge() &&
|
||||
e->src() == src &&
|
||||
bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
|
||||
for (const Edge* e : dst->in_edges()) {
|
||||
if (e->IsControlEdge() && e->src() == src &&
|
||||
e->src_output() == Graph::kControlSlot &&
|
||||
e->dst_input() == Graph::kControlSlot) {
|
||||
return true;
|
||||
|
@ -702,12 +702,16 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources) const {
resources,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algorithm should converge in at most
// num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
// The same applies to resources.
const int64 num_loops = new_shapes->size();
VLOG(1) << "Propagating (relax=" << relax << ") " << new_shapes->size()
<< " new shapes through " << num_loops << " loops and "
<< resources.size() << " resources" << std::endl;

const int64 max_loop_length = item_.graph.node_size();
const int64 max_rank = 4;
const int64 max_loop_iterations =
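The declaration of max_loop_iterations is cut off here, so the exact formula is not visible in this view; the comment above only states the num_nested_loops^2 * max_rank bound. A minimal standalone sketch of a cap with that shape (hypothetical names, not the elided TensorFlow code):

#include <algorithm>
#include <cstdint>

// Illustrative only: bound the total number of node updates so that shape
// propagation terminates even if a shape function never converges.
int64_t ShapePropagationCap(int64_t num_loops, int64_t max_loop_length) {
  const int64_t max_rank = 4;  // the approximation mentioned in the comment
  const int64_t per_node =
      std::max<int64_t>(1, num_loops * num_loops * max_rank);
  return per_node * max_loop_length;  // at least one pass over the graph
}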
@ -721,9 +725,12 @@ Status GraphProperties::PropagateShapes(
|
||||
while (!new_shapes->empty() &&
|
||||
num_loop_iterations++ < max_loop_iterations) {
|
||||
const Node* n = new_shapes->pop();
|
||||
for (const Node* fanout : n->out_nodes()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateShapes(shape_refiner, relax, fanout, new_shapes));
|
||||
for (const Edge* e : n->out_edges()) {
|
||||
if (!e->IsControlEdge()) {
|
||||
const Node* fanout = e->dst();
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateShapes(shape_refiner, relax, fanout, new_shapes));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -818,6 +825,7 @@ Status GraphProperties::InferStatically() {
|
||||
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
|
||||
std::unordered_set<const Node*> enter_nodes;
|
||||
std::unordered_set<const Node*> merge_nodes;
|
||||
int num_loops = 0;
|
||||
for (const Node* const node : graph.nodes()) {
|
||||
for (int i = 0; i < node->num_inputs(); ++i) {
|
||||
if (node->input_type(i) == DataType::DT_RESOURCE) {
|
||||
@ -830,6 +838,8 @@ Status GraphProperties::InferStatically() {
|
||||
enter_nodes.insert(node);
|
||||
} else if (node->IsMerge()) {
|
||||
merge_nodes.insert(node);
|
||||
} else if (node->IsNextIteration()) {
|
||||
++num_loops;
|
||||
}
|
||||
}
|
||||
|
||||
@ -853,7 +863,7 @@ Status GraphProperties::InferStatically() {
|
||||
}
|
||||
// Propagate shapes normally.
|
||||
TF_RETURN_IF_ERROR(
|
||||
PropagateShapes(&refiner, relax, &new_shapes, resources));
|
||||
PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
|
||||
}
|
||||
|
||||
// Track shapes globally across the graph.
|
||||
@ -906,6 +916,9 @@ Status GraphProperties::InferStatically() {
|
||||
&input_properties[i]);
|
||||
}
|
||||
for (const auto& edge : node->in_edges()) {
|
||||
if (edge->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
if (!edge->src()->IsConstant()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -108,7 +108,8 @@ class GraphProperties {
|
||||
Status PropagateShapes(
|
||||
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
|
||||
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
|
||||
resources) const;
|
||||
resources,
|
||||
int num_loops) const;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -24,64 +24,40 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
bool IsAdd(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Add";
|
||||
}
|
||||
bool IsAdd(const NodeDef& node) { return node.op() == "Add"; }
|
||||
|
||||
bool IsAddN(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "AddN";
|
||||
}
|
||||
bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
|
||||
|
||||
bool IsAvgPoolGrad(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "AvgPoolGrad";
|
||||
}
|
||||
bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
|
||||
|
||||
bool IsBiasAddGrad(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "BiasAddGrad";
|
||||
}
|
||||
bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
|
||||
|
||||
bool IsConcatOffset(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "ConcatOffset";
|
||||
}
|
||||
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
|
||||
|
||||
bool IsConstant(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Const";
|
||||
}
|
||||
bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
|
||||
|
||||
bool IsConv2D(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Conv2D";
|
||||
}
|
||||
bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
|
||||
|
||||
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
|
||||
|
||||
bool IsConv2DBackpropFilter(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Conv2DBackpropFilter";
|
||||
return node.op() == "Conv2DBackpropFilter";
|
||||
}
|
||||
|
||||
bool IsConv2DBackpropInput(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Conv2DBackpropInput";
|
||||
return node.op() == "Conv2DBackpropInput";
|
||||
}
|
||||
|
||||
bool IsDepthwiseConv2dNative(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "DepthwiseConv2dNative";
|
||||
return node.op() == "DepthwiseConv2dNative";
|
||||
}
|
||||
|
||||
bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "DepthwiseConv2dNativeBackpropFilter";
|
||||
return node.op() == "DepthwiseConv2dNativeBackpropFilter";
|
||||
}
|
||||
|
||||
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "DepthwiseConv2dNativeBackpropInput";
|
||||
return node.op() == "DepthwiseConv2dNativeBackpropInput";
|
||||
}
|
||||
|
||||
bool IsDequeueOp(const NodeDef& node) {
|
||||
@ -101,14 +77,10 @@ bool IsExit(const NodeDef& node) {
|
||||
return op == "Exit" || op == "RefExit";
|
||||
}
|
||||
|
||||
bool IsFloorMod(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "FloorMod";
|
||||
}
|
||||
bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
|
||||
|
||||
bool IsFusedBatchNormGradV1(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "FusedBatchNormGrad";
|
||||
return node.op() == "FusedBatchNormGrad";
|
||||
}
|
||||
|
||||
bool IsIdentity(const NodeDef& node) {
|
||||
@ -121,25 +93,16 @@ bool IsMerge(const NodeDef& node) {
|
||||
return op == "Merge" || op == "RefMerge";
|
||||
}
|
||||
|
||||
bool IsMul(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Mul";
|
||||
}
|
||||
bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
|
||||
|
||||
bool IsNoOp(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "NoOp";
|
||||
}
|
||||
bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
|
||||
|
||||
bool IsNextIteration(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "NextIteration" || op == "RefNextIteration";
|
||||
}
|
||||
|
||||
bool IsPad(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Pad";
|
||||
}
|
||||
bool IsPad(const NodeDef& node) { return node.op() == "Pad"; }
|
||||
|
||||
bool IsPlaceholder(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
@ -147,20 +110,11 @@ bool IsPlaceholder(const NodeDef& node) {
|
||||
op == "PlaceholderWithDefault";
|
||||
}
|
||||
|
||||
bool IsRealDiv(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "RealDiv";
|
||||
}
|
||||
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
|
||||
|
||||
bool IsReluGrad(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "ReluGrad";
|
||||
}
|
||||
bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
|
||||
|
||||
bool IsRecv(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "_Recv";
|
||||
}
|
||||
bool IsRecv(const NodeDef& node) { return node.op() == "_Recv"; }
|
||||
|
||||
bool IsReduction(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
@ -175,53 +129,34 @@ bool IsRestore(const NodeDef& node) {
|
||||
node.op() == "RestoreSlice");
|
||||
}
|
||||
|
||||
bool IsSend(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "_Send";
|
||||
}
|
||||
bool IsSend(const NodeDef& node) { return node.op() == "_Send"; }
|
||||
|
||||
bool IsSlice(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Slice";
|
||||
}
|
||||
bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
|
||||
|
||||
bool IsSquaredDifference(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "SquaredDifference";
|
||||
return node.op() == "SquaredDifference";
|
||||
}
|
||||
|
||||
bool IsSqueeze(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Squeeze";
|
||||
}
|
||||
bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
|
||||
|
||||
bool IsStopGradient(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "StopGradient" || op == "PreventGradient";
|
||||
}
|
||||
|
||||
bool IsSub(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Sub";
|
||||
}
|
||||
bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
|
||||
|
||||
bool IsSum(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Sum";
|
||||
}
|
||||
bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
|
||||
|
||||
bool IsSwitch(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Switch" || op == "RefSwitch";
|
||||
}
|
||||
|
||||
bool IsTranspose(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Transpose";
|
||||
}
|
||||
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
|
||||
|
||||
bool IsVariable(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
const auto& op = node.op();
|
||||
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
|
||||
op == "VarHandleOp" || op == "ReadVariableOp";
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ namespace grappler {
|
||||
bool IsAdd(const NodeDef& node);
|
||||
bool IsAddN(const NodeDef& node);
|
||||
bool IsAvgPoolGrad(const NodeDef& node);
|
||||
bool IsAssert(const NodeDef& node);
|
||||
bool IsBiasAddGrad(const NodeDef& node);
|
||||
bool IsConcatOffset(const NodeDef& node);
|
||||
bool IsConstant(const NodeDef& node);
|
||||
|
@ -448,6 +448,10 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
|
||||
if (node.device().find("SPU") != string::npos) {
|
||||
return false;
|
||||
}
|
||||
// Workaround for Assert mistakenly being labeled as stateful.
|
||||
if (IsAssert(node)) {
|
||||
return true;
|
||||
}
|
||||
return IsFreeOfSideEffect(node);
|
||||
}
|
||||
|
||||
|
@ -81,6 +81,38 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
|
||||
EXPECT_EQ("c1", new_mul.input(1));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
|
||||
Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
|
||||
auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
|
||||
auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
|
||||
auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
|
||||
auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
|
||||
Output mul = ops::Multiply(s.WithOpName("mul").WithControlDependencies(
|
||||
{assert1.operation, assert2.operation}),
|
||||
check1, check2);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
ArithmeticOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
// Run the optimizer twice to make sure the rewrite is idempotent.
|
||||
item.graph.Swap(&output);
|
||||
status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
EXPECT_EQ(5, output.node_size());
|
||||
const NodeDef& new_mul = output.node(3);
|
||||
EXPECT_EQ(4, new_mul.input_size());
|
||||
EXPECT_EQ("check1", new_mul.input(0));
|
||||
EXPECT_EQ("check1", new_mul.input(1));
|
||||
EXPECT_EQ("^assert1", new_mul.input(2));
|
||||
EXPECT_EQ("^assert1", new_mul.input(3));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
|
||||
|
@ -1720,6 +1720,7 @@ tf_cuda_cc_tests(
|
||||
":data_flow",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -97,8 +97,9 @@ class BincountOp : public OpKernel {
|
||||
const Tensor& weights_t = ctx->input(2);
|
||||
|
||||
int32 size = size_tensor.scalar<int32>()();
|
||||
OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument(
|
||||
"size (", size, ") must be non-negative"));
|
||||
OP_REQUIRES(
|
||||
ctx, size >= 0,
|
||||
errors::InvalidArgument("size (", size, ") must be non-negative"));
|
||||
|
||||
const auto arr = arr_t.flat<int32>();
|
||||
const auto weights = weights_t.flat<T>();
|
||||
|
@ -16,11 +16,11 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_BINCOUNT_OP_H_
|
||||
#define TENSORFLOW_BINCOUNT_OP_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -17,12 +17,12 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/bincount_op.h"
|
||||
#include "external/cub_archive/cub/device/device_histogram.cuh"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bincount_op.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
@ -93,8 +93,8 @@ struct BincountFunctor<GPUDevice, T> {
|
||||
/* num_samples */ num_samples,
|
||||
/* stream */ stream);
|
||||
if (err != cudaSuccess) {
|
||||
return errors::Internal("Could not launch HistogramEven: ",
|
||||
cudaGetErrorString(err), ".");
|
||||
return errors::Internal(
|
||||
"Could not launch HistogramEven: ", cudaGetErrorString(err), ".");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -30,8 +30,8 @@ static Graph* Bincount(int arr_size, int nbins) {
|
||||
Tensor arr(DT_INT32, TensorShape({arr_size}));
|
||||
arr.flat<int32>() = arr.flat<int32>().setRandom().abs();
|
||||
|
||||
Tensor size(DT_INT32, TensorShape({(int32)1}));
|
||||
size.flat<int32>()(0) = (int32)nbins;
|
||||
Tensor size(DT_INT32, TensorShape({static_cast<int32>(1)}));
|
||||
size.flat<int32>()(0) = static_cast<int32>(nbins);
|
||||
|
||||
Tensor weights(DT_INT32, TensorShape({0}));
|
||||
|
||||
|
@ -77,10 +77,10 @@ struct BucketizeFunctor<GPUDevice, T> {
|
||||
TF_RETURN_IF_ERROR(boundaries_array.Finalize());
|
||||
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(input.size(), d);
|
||||
BucketizeCustomKernel<
|
||||
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
input.size(), input.data(), boundaries_vector.size(),
|
||||
boundaries_array.data(), output.data());
|
||||
BucketizeCustomKernel<T>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
input.size(), input.data(), boundaries_vector.size(),
|
||||
boundaries_array.data(), output.data());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -1101,29 +1101,27 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
|
||||
bool cudnn_use_autotune_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
Conv3DBackpropInputOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("input_sizes"), \
|
||||
Conv3DBackpropInputOp<GPUDevice, T>); \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("input_sizes"), \
|
||||
Conv3DBackpropInputOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
Conv3DBackpropFilterOp<GPUDevice, T>); \
|
||||
Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
Conv3DBackpropFilterOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("filter_sizes"), \
|
||||
Conv3DBackpropFilterOp<GPUDevice, T>);
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("filter_sizes"), \
|
||||
Conv3DBackpropFilterOp<GPUDevice, T>);
|
||||
TF_CALL_half(REGISTER_GPU_KERNEL);
|
||||
TF_CALL_float(REGISTER_GPU_KERNEL);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -22,7 +22,7 @@ REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double,
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER2(UnaryOp, SYCL, "Asinh", functor::asinh, float, double);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER2(UnaryOp, GPU, "Asinh", functor::asinh, float, double);
|
||||
|
@ -22,7 +22,7 @@ REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double,
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER2(UnaryOp, SYCL, "Atanh", functor::atanh, float, double);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);
|
||||
|
@ -231,7 +231,8 @@ static void CopyOutputBackpropRegion(const DepthwiseArgs& args,
|
||||
}
|
||||
// Pad to vector-register width (if needed).
|
||||
for (int64 d = 0; d < pad_size; ++d) {
|
||||
buffer[buf_base + vectorized_size + scalar_size + d] = static_cast<T>(0);
|
||||
buffer[buf_base + vectorized_size + scalar_size + d] =
|
||||
static_cast<T>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -510,7 +511,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, Eigen::half>;
|
||||
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
|
||||
Eigen::half>;
|
||||
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
|
||||
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;
|
||||
|
||||
@ -885,7 +887,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, Eigen::half>;
|
||||
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
|
||||
Eigen::half>;
|
||||
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
|
||||
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;
|
||||
|
||||
|
@ -158,7 +158,8 @@ struct DepthwiseFilterPadOp {
|
||||
}
|
||||
// Pad the remainder of output to vector-register boundary.
|
||||
for (int64 j = 0; j < pad_size; ++j) {
|
||||
padded_filter[output_base + vectorized_size + scalar_size + j] = static_cast<T>(0);
|
||||
padded_filter[output_base + vectorized_size + scalar_size + j] =
|
||||
static_cast<T>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -73,18 +73,22 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::move(other_arguments),
|
||||
&captured_func));
|
||||
|
||||
*output = new Dataset(input, std::move(captured_func), cycle_length,
|
||||
block_length, output_types_, output_shapes_);
|
||||
*output =
|
||||
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
|
||||
block_length, output_types_, output_shapes_);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
Dataset(const DatasetBase* input,
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
const NameAttrList& func,
|
||||
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
|
||||
int64 block_length, const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: input_(input),
|
||||
: GraphDatasetBase(ctx),
|
||||
input_(input),
|
||||
func_(func),
|
||||
captured_func_(std::move(captured_func)),
|
||||
cycle_length_(cycle_length),
|
||||
block_length_(block_length),
|
||||
@ -110,13 +114,47 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
string DebugString() override { return "InterleaveDatasetOp::Dataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
|
||||
Node* input_node;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
|
||||
Node* cycle_length_node;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
|
||||
Node* block_length_node;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
|
||||
DataTypeVector other_arguments_types;
|
||||
other_arguments_types.reserve(captured_func_->captured_inputs().size());
|
||||
std::vector<NodeBuilder::NodeOut> other_arguments;
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
AttrValue f;
|
||||
b->BuildAttrValue(func_, &f);
|
||||
AttrValue other_arguments_types_attr;
|
||||
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
|
||||
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this,
|
||||
{{0, input_node}, {2, cycle_length_node}, {3, block_length_node}},
|
||||
{{1, other_arguments}},
|
||||
{{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params),
|
||||
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
|
||||
current_elements_(params.dataset->cycle_length_) {}
|
||||
current_elements_(params.dataset->cycle_length_),
|
||||
args_list_(params.dataset->cycle_length_) {}
|
||||
|
||||
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
block_index_ = 0;
|
||||
@ -150,18 +188,19 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
// We have reached the end of the current element, so move
|
||||
// on to the next element in the cycle.
|
||||
current_elements_[cycle_index_].reset();
|
||||
args_list_[cycle_index_].clear();
|
||||
--num_open_;
|
||||
AdvanceToNextInCycle();
|
||||
} else if (!end_of_input_) {
|
||||
// Get the next element from the input dataset, and create
|
||||
// an iterator from it.
|
||||
std::vector<Tensor> args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, &args, &end_of_input_));
|
||||
TF_RETURN_IF_ERROR(input_impl_->GetNext(
|
||||
ctx, &args_list_[cycle_index_], &end_of_input_));
|
||||
if (!end_of_input_) {
|
||||
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
|
||||
ctx, args, cycle_index_, dataset()->captured_func_.get(),
|
||||
prefix(), ¤t_elements_[cycle_index_]));
|
||||
ctx, args_list_[cycle_index_], cycle_index_,
|
||||
dataset()->captured_func_.get(), prefix(),
|
||||
¤t_elements_[cycle_index_]));
|
||||
++num_open_;
|
||||
}
|
||||
} else {
|
||||
@ -173,11 +212,100 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("cycle_index"), cycle_index_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("block_index"), block_index_));
|
||||
if (end_of_input_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("end_of_input"), ""));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("num_open"), num_open_));
|
||||
TF_RETURN_IF_ERROR(SaveCurrentElements(writer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
int64 cycle_index;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("cycle_index"), &cycle_index));
|
||||
cycle_index_ = size_t(cycle_index);
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("block_index"), &block_index_));
|
||||
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
|
||||
int64 num_open;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("num_open"), &num_open));
|
||||
num_open_ = size_t(num_open);
|
||||
TF_RETURN_IF_ERROR(RestoreCurrentElements(ctx, reader));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
Status SaveCurrentElements(IteratorStateWriter* writer)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
for (int idx = 0; idx < current_elements_.size(); idx++) {
|
||||
if (current_elements_[idx]) {
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, current_elements_[idx]));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat("args_size[", idx, "]")),
|
||||
args_list_[idx].size()));
|
||||
for (int i = 0; i < args_list_[idx].size(); i++) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
|
||||
args_list_[idx][i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreCurrentElements(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
IteratorContext::Params params;
|
||||
params.env = ctx->env();
|
||||
params.runner = *(ctx->runner());
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
for (int idx = 0; idx < current_elements_.size(); idx++) {
|
||||
if (reader->Contains(
|
||||
full_name(strings::StrCat("args_size[", idx, "]")))) {
|
||||
int64 args_size;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
full_name(strings::StrCat("args_size[", idx, "]")),
|
||||
&args_size));
|
||||
args_list_[idx].resize(args_size);
|
||||
for (int i = 0; i < args_size; i++) {
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
|
||||
&args_list_[idx][i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
|
||||
&iter_ctx, args_list_[idx], idx,
|
||||
dataset()->captured_func_.get(), prefix(),
|
||||
¤t_elements_[idx]));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestoreParent(ctx, reader, current_elements_[idx]));
|
||||
} else {
|
||||
current_elements_[idx].reset();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
std::vector<std::unique_ptr<IteratorBase>> current_elements_
|
||||
GUARDED_BY(mu_);
|
||||
std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
|
||||
size_t cycle_index_ GUARDED_BY(mu_) = 0;
|
||||
int64 block_index_ GUARDED_BY(mu_) = 0;
|
||||
bool end_of_input_ GUARDED_BY(mu_) = false;
|
||||
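As a quick reference for the checkpoint layout used by SaveCurrentElements and RestoreCurrentElements above: each cycle slot idx stores an "args_size[idx]" scalar plus one "args_list_[idx][i]" tensor per captured argument, and a slot with no "args_size[idx]" entry is restored as empty. A hypothetical sketch of how such keys are composed (the real code goes through full_name() and strings::StrCat):

#include <string>

// Illustrative only: per-slot checkpoint key names.
std::string ArgsSizeKey(int idx) {
  return "args_size[" + std::to_string(idx) + "]";
}
std::string ArgsTensorKey(int idx, int i) {
  return "args_list_[" + std::to_string(idx) + "][" + std::to_string(i) + "]";
}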
@ -185,6 +313,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const NameAttrList func_;
|
||||
const std::unique_ptr<CapturedFunction> captured_func_;
|
||||
const int64 cycle_length_;
|
||||
const int64 block_length_;
|
||||
|
@ -258,7 +258,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
EnsureOutputAllocated(batch_result, result->return_values);
|
||||
const size_t num_components = result->return_values.size();
|
||||
for (size_t i = 0; i < num_components; ++i) {
|
||||
Tensor tensor = result->return_values[i];
|
||||
const Tensor& tensor = result->return_values[i];
|
||||
Tensor* batch = &(batch_result->output)[i];
|
||||
if (tensor.NumElements() !=
|
||||
(batch->NumElements() / batch->dim_size(0))) {
|
||||
@ -271,6 +271,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
", [batch]: ", batch_shape.DebugString()));
|
||||
break;
|
||||
}
|
||||
// TODO(mrry): Add a version of DoParallelConcat that allows
|
||||
// us to move `tensor` where possible, to speed up string
|
||||
// tensor batching.
|
||||
Status copy_status = ::tensorflow::functor::DoParallelConcat(
|
||||
*dataset()->device_, tensor, offset, batch);
|
||||
if (!copy_status.ok()) {
|
||||
@ -279,6 +282,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
// NOTE(mrry): We clear the return values here to release any
|
||||
// memory associated with them and to paralellize the destruction
|
||||
// of the tensors (which can be surprisingly expensive for
|
||||
// map functions with large numbers of return values).
|
||||
result->return_values.clear();
|
||||
batch_result->counter->DecrementCount();
|
||||
});
|
||||
}
|
||||
@ -297,7 +305,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
for (size_t i = 0; i < dataset()->batch_size_; ++i) {
|
||||
size_t index = ComputeInvocationIndex(batch_index, i);
|
||||
InvocationResult* result = &invocation_results_[index];
|
||||
*result = InvocationResult();
|
||||
// Reset the state of `result`.
|
||||
// NOTE(mrry): `result->return_values` were cleared when the previous
|
||||
// invocation completed.
|
||||
result->status = Status::OK();
|
||||
}
|
||||
// Start individual invocations.
|
||||
for (size_t i = 0; i < dataset()->batch_size_; ++i) {
|
||||
|
@ -359,7 +359,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));

use_dnn_ = CanUseCudnn();
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}

void Compute(OpKernelContext* context) override {
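The change here (and in the three similar constructors below) wraps ReadBoolFromEnvVar in TF_CHECK_OK, so a malformed TF_ENABLE_MAXPOOL_NANPROP value now fails loudly instead of the returned Status being dropped. A rough standalone sketch of a boolean env-var reader that reports failure through its return value (not the TensorFlow implementation):

#include <cstdlib>
#include <string>

// Illustrative only: parse a boolean environment variable and signal a parse
// failure to the caller instead of silently falling back to the default.
bool ReadBoolEnv(const char* name, bool default_val, bool* out) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) { *out = default_val; return true; }
  const std::string v(raw);
  if (v == "1" || v == "true") { *out = true; return true; }
  if (v == "0" || v == "false") { *out = false; return true; }
  *out = default_val;
  return false;  // invalid value; the caller should treat this as an error
}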
@ -888,7 +889,8 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
|
||||
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
|
||||
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
|
||||
&propagate_nans_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -1052,7 +1054,8 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
use_dnn_ = CanUseCudnn();
|
||||
|
||||
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
|
||||
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
|
||||
&propagate_nans_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -1137,7 +1140,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
|
||||
}
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
use_dnn_ = CanUseCudnn();
|
||||
ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_);
|
||||
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
|
||||
&propagate_nans_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
|
@ -405,17 +405,17 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
|
||||
if (propagate_nans) {
|
||||
MaxPoolForwardNHWC<true>
|
||||
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
|
||||
kThreadsPerBlock, 0, d.stream()>>>
|
||||
(output_size, bottom_data, height, width, channels, pooled_height,
|
||||
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
|
||||
top_data, mask);
|
||||
kThreadsPerBlock, 0, d.stream()>>>(
|
||||
output_size, bottom_data, height, width, channels, pooled_height,
|
||||
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
|
||||
top_data, mask);
|
||||
} else {
|
||||
MaxPoolForwardNHWC<false>
|
||||
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
|
||||
kThreadsPerBlock, 0, d.stream()>>>
|
||||
(output_size, bottom_data, height, width, channels, pooled_height,
|
||||
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
|
||||
top_data, mask);
|
||||
kThreadsPerBlock, 0, d.stream()>>>(
|
||||
output_size, bottom_data, height, width, channels, pooled_height,
|
||||
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
|
||||
top_data, mask);
|
||||
}
|
||||
return d.ok();
|
||||
}
|
||||
|
@ -101,8 +101,8 @@ class MklToTfOp : public OpKernel {
|
||||
// Allocate output tensor.
|
||||
TensorShape output_shape = input_shape.GetTfShape();
|
||||
Tensor* output_tensor = NULL;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(input_number,
|
||||
output_shape, &output_tensor));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
input_number, output_shape, &output_tensor));
|
||||
CHECK_NOTNULL(output_tensor);
|
||||
|
||||
// Do we need to reorder Mkl layout into TensorFlow layout?
|
||||
@ -116,13 +116,13 @@ class MklToTfOp : public OpKernel {
|
||||
// If not, just forward input tensor to output tensor.
|
||||
CHECK(output_tensor->CopyFrom(input_tensor, output_shape));
|
||||
}
|
||||
} catch (mkldnn::error &e) {
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + std::string(e.message) +
|
||||
", in file " + std::string(__FILE__) + ":" +
|
||||
std::to_string(__LINE__);
|
||||
OP_REQUIRES_OK(context,
|
||||
errors::Aborted("Operation received an exception:", error_msg));
|
||||
", message: " + std::string(e.message) + ", in file " +
|
||||
std::string(__FILE__) + ":" + std::to_string(__LINE__);
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
errors::Aborted("Operation received an exception:", error_msg));
|
||||
}
|
||||
}
|
||||
#else
|
||||
@ -160,8 +160,8 @@ class MklToTfOp : public OpKernel {
|
||||
|
||||
// Allocate output tensor.
|
||||
Tensor* output_tensor = NULL;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(input_number,
|
||||
output_shape, &output_tensor));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(input_number, output_shape,
|
||||
&output_tensor));
|
||||
|
||||
dnnLayout_t output_layout =
|
||||
static_cast<dnnLayout_t>(input_shape.GetTfLayout());
|
||||
|
@ -98,6 +98,19 @@ gtl::InlinedVector<T, 8> ComputeStride(const TensorShape& shape) {
return strides;
}

// Helper to compute 'strides' given an Eigen TensorDimensions
template <typename T, typename EigenDimensions>
gtl::InlinedVector<T, 8> ComputeEigenStrides(const EigenDimensions& shape) {
const int ndims = shape.rank();
gtl::InlinedVector<T, 8> strides(ndims);
T stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[i] = stride;
stride *= static_cast<T>(shape[i]);
}
return strides;
}
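A minimal standalone illustration of the row-major stride computation the new helper performs, using plain std::vector in place of gtl::InlinedVector and a dimension vector in place of an Eigen dimensions object (illustrative only):

#include <vector>

// Illustrative only: for a row-major shape {2, 3, 4} this yields {12, 4, 1},
// i.e. strides[i] is the product of all dimensions to the right of i.
std::vector<long long> RowMajorStrides(const std::vector<long long>& dims) {
  std::vector<long long> strides(dims.size());
  long long stride = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= dims[i];
  }
  return strides;
}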
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
|
||||
|
@ -181,16 +181,18 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
padding_values.push_back(tensor::DeepCopy(padding_value_t));
|
||||
}
|
||||
|
||||
*output = new Dataset(batch_size, std::move(padded_shapes),
|
||||
*output = new Dataset(ctx, batch_size, std::move(padded_shapes),
|
||||
std::move(padding_values), input);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
Dataset(int64 batch_size, std::vector<PartialTensorShape> padded_shapes,
|
||||
Dataset(OpKernelContext* ctx, int64 batch_size,
|
||||
std::vector<PartialTensorShape> padded_shapes,
|
||||
std::vector<Tensor> padding_values, const DatasetBase* input)
|
||||
: batch_size_(batch_size),
|
||||
: GraphDatasetBase(ctx),
|
||||
batch_size_(batch_size),
|
||||
padded_shapes_(std::move(padded_shapes)),
|
||||
padding_values_(std::move(padding_values)),
|
||||
input_(input) {
|
||||
@ -232,6 +234,47 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
")::Dataset");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
|
||||
Node* batch_size = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
|
||||
|
||||
std::vector<NodeBuilder::NodeOut> padded_shapes;
|
||||
padded_shapes.reserve(padded_shapes_.size());
|
||||
for (int i = 0; i < padded_shapes_.size(); i++) {
|
||||
Node* node;
|
||||
Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
|
||||
for (int j = 0; j < padded_shapes_[i].dims(); j++) {
|
||||
t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
padded_shapes.emplace_back(node);
|
||||
}
|
||||
|
||||
std::vector<NodeBuilder::NodeOut> padding_values;
|
||||
padding_values.reserve(padding_values_.size());
|
||||
for (const Tensor& t : padding_values_) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
padding_values.emplace_back(node);
|
||||
}
|
||||
|
||||
AttrValue output_types;
|
||||
b->BuildAttrValue(output_dtypes(), &output_types);
|
||||
|
||||
AttrValue N;
|
||||
b->BuildAttrValue<int64>(padded_shapes_.size(), &N);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {{0, input_graph_node}, {1, batch_size}},
|
||||
{{2, padded_shapes}, {3, padding_values}},
|
||||
{{"Toutput_types", output_types}, {"N", N}}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// Copies element into the index^th slice of parent (in the 0th dimension).
|
||||
//
|
||||
@ -248,17 +291,25 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Each row of `batch_elements` is a tuple of tensors from the
|
||||
// input iterator.
|
||||
std::vector<std::vector<Tensor>> batch_elements;
|
||||
batch_elements.reserve(dataset()->batch_size_);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
*end_of_sequence = false;
|
||||
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
|
||||
++i) {
|
||||
std::vector<Tensor> batch_element_tuple;
|
||||
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
|
||||
end_of_sequence));
|
||||
if (!*end_of_sequence) {
|
||||
batch_elements.push_back(std::move(batch_element_tuple));
|
||||
if (!input_impl_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
*end_of_sequence = false;
|
||||
batch_elements.reserve(dataset()->batch_size_);
|
||||
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
|
||||
++i) {
|
||||
std::vector<Tensor> batch_element_tuple;
|
||||
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
|
||||
end_of_sequence));
|
||||
if (!*end_of_sequence) {
|
||||
batch_elements.push_back(std::move(batch_element_tuple));
|
||||
}
|
||||
}
|
||||
if (*end_of_sequence) {
|
||||
input_impl_.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -347,6 +398,28 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
else
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("exhausted"), ""));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
if (reader->Contains(full_name("exhausted"))) {
|
||||
input_impl_.reset();
|
||||
} else {
|
||||
input_impl_ = dataset()->input_->MakeIterator(prefix());
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
|
@ -352,13 +352,15 @@ class DeserializeSparseOp : public OpKernel {
|
||||
i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
|
||||
"] is: ", expanded_tensor_shape.dims() - 1));
|
||||
for (int j = 1; j < shape.dims(); ++j) {
|
||||
OP_REQUIRES(
|
||||
context, shape.dim_size(j) == expanded_tensor_shape.dim_size(j),
|
||||
errors::InvalidArgument(
|
||||
"Inconsistent shape across SparseTensors: dimension ", j - 1,
|
||||
" prior to SparseTensor[", i, "] was: ", shape.dim_size(j),
|
||||
" but rank of SparseTensor[", i,
|
||||
"] is: ", expanded_tensor_shape.dim_size(j)));
|
||||
// NOTE(mrry): For compatibility with the implementations of
|
||||
// DeserializeManySparse, and many ops that generate
|
||||
// SparseTensors to batch that do not have a fixed
|
||||
// dense_shape (e.g. `tf.parse_single_example()`), we
|
||||
// compute the maximum in each dimension to find the
|
||||
// smallest dense_shape that bounds all of the input
|
||||
// SparseTensors.
|
||||
shape.set_dim(j, std::max(shape.dim_size(j),
|
||||
expanded_tensor_shape.dim_size(j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/cloud/curl_http_request.h"
|
||||
#include "tensorflow/core/platform/cloud/file_block_cache.h"
|
||||
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
|
||||
@ -696,6 +697,18 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);

if (out->size() < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
if (stat_cache_->Lookup(filename, &stat)) {
if (offset + out->size() < stat.length) {
return errors::Internal(strings::Printf(
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));
}
}
}

return Status::OK();
}
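The check added above can be read as a single predicate: a block read that came back short is acceptable only if it actually reached the end of the file according to the cached stat. A compact restatement with hypothetical names (the real code consults the GCS stat cache and returns errors::Internal):

#include <cstddef>

// Illustrative only: true when a short read stopped before the known end of
// the file, i.e. the fetched contents are inconsistent with the cached size.
bool LooksTruncated(std::size_t offset, std::size_t bytes_read,
                    std::size_t block_size, std::size_t cached_file_length) {
  return bytes_read < block_size && offset + bytes_read < cached_file_length;
}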
@ -816,7 +829,8 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
|
||||
return errors::Internal("'stat' cannot be nullptr.");
|
||||
}
|
||||
if (object.empty()) {
|
||||
return errors::InvalidArgument("'object' must be a non-empty string.");
|
||||
return errors::InvalidArgument(strings::Printf(
|
||||
"'object' must be a non-empty string. (File: %s)", fname.c_str()));
|
||||
}
|
||||
|
||||
StatCache::ComputeFunc compute_func =
|
||||
|
@ -131,8 +131,8 @@ error::Code ErrnoToCode(int err_number) {
|
||||
case ENETUNREACH: // Network unreachable
|
||||
case ENOLCK: // No locks available
|
||||
case ENOLINK: // Link has been severed
|
||||
#if !(defined(__APPLE__) || defined(__FreeBSD__) || defined(_WIN32) \
|
||||
|| defined(__HAIKU__))
|
||||
#if !(defined(__APPLE__) || defined(__FreeBSD__) || defined(_WIN32) || \
|
||||
defined(__HAIKU__))
|
||||
case ENONET: // Machine is not on the network
|
||||
#endif
|
||||
code = error::UNAVAILABLE;
|
||||
|
@ -37,8 +37,8 @@ limitations under the License.
|
||||
#ifdef TF_USE_SNAPPY
|
||||
#include "snappy.h"
|
||||
#endif
|
||||
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) \
|
||||
|| defined(__HAIKU__)
|
||||
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) || \
|
||||
defined(__HAIKU__)
|
||||
#include <thread>
|
||||
#endif
|
||||
|
||||
@ -62,8 +62,8 @@ int NumSchedulableCPUs() {
|
||||
}
|
||||
perror("sched_getaffinity");
|
||||
#endif
|
||||
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) \
|
||||
|| defined(__HAIKU__)
|
||||
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) || \
|
||||
defined(__HAIKU__)
|
||||
unsigned int count = std::thread::hardware_concurrency();
|
||||
if (count > 0) return static_cast<int>(count);
|
||||
#endif
|
||||
|