Merge pull request #11796 from vrv/branch_163282839

Branch 163282839
This commit is contained in:
Vijay Vasudevan 2017-07-26 21:15:54 -07:00 committed by GitHub
commit 3f1a4ecdb5
35 changed files with 505 additions and 473 deletions

View File

@ -1944,6 +1944,7 @@ cc_library(
":buffer_liveness",
":hlo",
":hlo_pass",
":hlo_pass_pipeline",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
],

View File

@ -254,13 +254,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
HloPassPipeline pipeline("CPU");
pipeline.AddInvariantChecker<HloVerifier>();
for (const auto& reduce_precision_options :
module->config().debug_options().hlo_reduce_precision_options()) {
if (reduce_precision_options.pass_timing() ==
HloReducePrecisionOptions::BEFORE_OP_FUSION) {
pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
}
}
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
HloReducePrecisionOptions::BEFORE_OP_FUSION);
// TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
// where we will take this pass in future.
@ -288,13 +284,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
for (const auto& reduce_precision_options :
module->config().debug_options().hlo_reduce_precision_options()) {
if (reduce_precision_options.pass_timing() ==
HloReducePrecisionOptions::AFTER_OP_FUSION) {
pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
}
}
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
HloReducePrecisionOptions::AFTER_OP_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
module->mutable_entry_computation_layout());

View File

@ -124,15 +124,9 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
for (const auto& reduce_precision_options :
hlo_module->config().debug_options().hlo_reduce_precision_options()) {
if (reduce_precision_options.pass_timing() ==
HloReducePrecisionOptions::BEFORE_OP_FUSION) {
pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
}
}
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
HloReducePrecisionOptions::BEFORE_OP_FUSION);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@ -162,14 +156,9 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
for (const auto& reduce_precision_options :
hlo_module->config().debug_options().hlo_reduce_precision_options()) {
if (reduce_precision_options.pass_timing() ==
HloReducePrecisionOptions::AFTER_OP_FUSION) {
reduce_pipeline.AddPass<ReducePrecisionInsertion>(
reduce_precision_options);
}
}
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
HloReducePrecisionOptions::AFTER_OP_FUSION);
StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
TF_RETURN_IF_ERROR(reduce_result.status());

View File

@ -191,12 +191,16 @@ string NodeColorAttributes(ColorScheme color) {
case kYellow:
return make_tuple("filled", "#fff9c4", "#cbc693", "black");
case kDashedBorder:
return make_tuple("dashed", "white", "#757575", "#757575");
// "filled,dashed" looks the same as "dashed", since we have a white
// background. But we use "filled,dashed" so that when you hover over
// any part of the node (not just the text inside the node), our css
// :hover rule is triggered.
return make_tuple("filled,dashed", "white", "#757575", "#757575");
}
}();
return Printf(
R"(style=%s, fontcolor="%s", color="%s", fillcolor="%s")", style,
R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style,
font_color, stroke_color, fill_color);
}
@ -304,6 +308,7 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
}
}
// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
@ -329,6 +334,9 @@ class HloDotDumper {
return StrCat("cluster_", reinterpret_cast<uint64>(computation));
}
// Generates graph header/footer. These should be called *after* dumping all
// of the instructions and subcomputations for the graph, as they both use
// data generated while dumping the graph.
string Header();
string Footer();
@ -360,6 +368,24 @@ class HloDotDumper {
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
// Each HloInstruction dumped gets a monotically-increasing node ID. This
// must start at 1, because that's where graphviz's accounting starts.
int64 next_node_id_ = 1;
std::unordered_map<const HloInstruction*, int64> node_ids_;
// Each (from, to) edge gets a monotonically-increasing ID. This is a
// multimap because it's possible for the same edge to appear multiple times
// in the graph (e.g. x^2 may be represented as mul(x, x)).
int64 next_edge_id_ = 1;
std::unordered_multimap<
std::pair<const HloInstruction*, const HloInstruction*>, int64,
tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
edge_ids_;
// Each HloComputation that's emitted gets a monotonically-increasing ID.
int64 next_cluster_id_ = 1;
std::unordered_map<const HloComputation*, int64> cluster_ids_;
// Edges to print from Footer(). Edges come at the end because graphviz is
// unhappy if an edge from a subcomputation to a node in the outer computation
// appears before both the inner computation and the destination node are
@ -368,25 +394,32 @@ class HloDotDumper {
};
string HloDotDumper::Dump() {
string g = Header();
string body;
for (const auto& kv : SubcomputationsToDump()) {
const HloComputation* subcomp = kv.first;
const HloInstruction* parent = kv.second;
StrAppend(&g, DumpSubcomputation(subcomp, parent));
StrAppend(&body, DumpSubcomputation(subcomp, parent));
}
StrAppend(&g, DumpComputation(computation_));
StrAppend(&body, DumpComputation(computation_));
// By contract, Header() and Footer() have to be called after we've dumped all
// our instructions, because they use state generated during that process.
string g = Header();
StrAppend(&g, body);
StrAppend(&g, Footer());
return g;
}
string HloDotDumper::Header() {
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
const char* fmt = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
labelloc = t;
// Disable the tooltip. Interestingly, "" doesn't work!
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
stylesheet="
data:text/css,
@import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
@ -394,6 +427,8 @@ stylesheet="
font-family: 'Roboto';
font-size: 12px;
}
%s
"
)";
@ -404,7 +439,59 @@ stylesheet="
Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
tensorflow::strings::HumanReadableNum(cycles));
}
return Printf(fmt, graph_label);
// Create CSS rules that say, when you hover over the given node or cluster,
// turn the given edge the given color.
//
// We rely on a few properties of how graphviz generates SVGs:
//
// - Nodes are named "nodeN", where N corresponds to the 1-based index of
// the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
// Edges are similarly named "edgeN", and clusters are named "clustN".
// - Nodes come before their in- and out-edges in the SVG. We need this
// because the "X ~ Y" CSS selector finds a sibling of X that *comes
// after X in the DOM* and matches Y.
std::vector<string> edge_css_rules;
const char* kBlue = "#1976d2";
const char* kRed = "#d32f2f";
for (const auto& kv : edge_ids_) {
const HloInstruction* from_node = kv.first.first;
const HloInstruction* to_node = kv.first.second;
int64 edge_id = kv.second;
auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
const char* color) {
// One could imagine other ways of writing this CSS rule that involve less
// duplication, but this way seems to be relatively performant.
edge_css_rules.push_back(Printf(
" #%s%d:hover ~ #edge%lld text { fill: %s; }\n"
" #%s%d:hover ~ #edge%lld path { stroke: %s; stroke-width: .2em; }\n"
" #%s%d:hover ~ #edge%lld polygon { "
"fill: %s; stroke: %s; stroke-width: .2em; }\n",
elem_type, elem_id, edge_id, color, //
elem_type, elem_id, edge_id, color, //
elem_type, elem_id, edge_id, color, color));
};
int64 from_node_id = node_ids_.at(from_node);
int64 to_node_id = node_ids_.at(to_node);
add_hover_css_rule("node", from_node_id, kBlue);
add_hover_css_rule("node", to_node_id, kRed);
// If this edge crosses a fusion cluster boundary, highlight it when the
// cluster is hovered over.
if (from_node->IsFused() &&
from_node->fusion_instruction()->fused_expression_root() == from_node) {
int64 cluster_id = cluster_ids_.at(from_node->parent());
add_hover_css_rule("clust", cluster_id, kBlue);
}
if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
int64 cluster_id = cluster_ids_.at(to_node->parent());
add_hover_css_rule("clust", cluster_id, kRed);
}
}
return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
}
string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
@ -440,11 +527,14 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
%s;
label = <%s>;
labelloc = t;
tooltip = " ";
%s
} // %s
)";
cluster_ids_[subcomp] = next_cluster_id_++;
string id = SubcomputationId(subcomp);
string subcomp_label, style;
@ -475,10 +565,14 @@ labelloc = t;
// belongs to a fusion node, it's drawn in place of the fusion instruction, so
// there's no need to link those.
if (parent_instr->opcode() != HloOpcode::kFusion) {
const char* edge_fmt = R"(%s -> %s [ltail="%s", style="dashed"];)";
edge_ids_.insert(
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
const char* edge_fmt =
R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
edges_.push_back(
Printf(edge_fmt, InstructionId(subcomp->root_instruction()),
InstructionId(parent_instr), SubcomputationId(subcomp)));
InstructionId(parent_instr), SubcomputationId(subcomp),
subcomp->name(), parent_instr->name()));
}
return computation;
@ -508,6 +602,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
return "";
}
node_ids_[instr] = next_node_id_++;
ColorScheme color = GetInstructionColor(instr);
string node_shape = GetInstructionNodeShape(instr);
string node_label = GetInstructionNodeLabel(instr);
@ -534,8 +630,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
}
}
return Printf("%s [label=<%s>, shape=%s, %s];\n", InstructionId(instr),
node_body, node_shape, NodeColorAttributes(color));
return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
"\n",
InstructionId(instr), node_body, node_shape,
NodeColorAttributes(color));
}
string HloDotDumper::GetInstructionNodeInlinedConstants(
@ -776,12 +874,15 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
return;
}
string edge = Printf("%s -> %s", InstructionId(from), InstructionId(to));
edge_ids_.insert({{from, to}, next_edge_id_++});
string edge_label;
if (instr->operand_count() > 1) {
Appendf(&edge, R"( [headlabel="%lld",labeldistance=2])", operand_num);
edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num);
}
StrAppend(&edge, ";");
edges_.push_back(edge);
const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)";
edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to),
from->name(), to->name(), edge_label));
};
// Add edges from instr's operands to instr. Parameters within fusion
@ -945,40 +1046,33 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
}
auto is_displayed = [&](const HloInstruction* instr) {
return nodes.count(instr) > 0;
// Constants are displayed inline with their users; they're never omitted.
return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant;
};
// Mark nodes which don't have all of their operands present as "some operands
// omitted".
// Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
// know which nodes will be included in the graph.
for (auto& kv : nodes) {
const HloInstruction* instr = kv.first;
NodeFilterResult& filter_result = kv.second;
const auto& operands = instr->operands();
// Mark nodes with some omitted as "some operands omitted".
if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
!std::all_of(operands.begin(), operands.end(), is_displayed)) {
// Mark nodes with some operands omitted appropriately.
filter_result = kSomeOperandsOmitted;
} else if (!operands.empty() &&
std::none_of(operands.begin(), operands.end(), is_displayed)) {
// Mark nodes with *all* operands omitted appropriately.
filter_result = kOmitNodeOperands;
}
}
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
// users made it into the graph by other means.
for (auto& kv : nodes) {
const auto& users = kv.first->users();
if (kv.second == kSomeUsersOmitted &&
std::all_of(users.begin(), users.end(), is_displayed)) {
kv.second = kNormalNode;
}
}
// If none of a node's operands appear in nodes, mark it as type
// kOmitNodeOperands so it gets styled appropriately.
for (auto& kv : nodes) {
const auto& operands = kv.first->operands();
if (!operands.empty() &&
std::none_of(operands.begin(), operands.end(), is_displayed)) {
kv.second = kOmitNodeOperands;
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
// users made it into the graph.
if (filter_result == kSomeUsersOmitted &&
std::all_of(instr->users().begin(), instr->users().end(),
is_displayed)) {
filter_result = kNormalNode;
}
}

View File

@ -912,7 +912,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateInfeed(shape, infeed_config());
case HloOpcode::kOutfeed:
CHECK_EQ(new_operands.size(), 1);
return CreateOutfeed(shape, new_operands[0], outfeed_config());
return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
case HloOpcode::kBatchNormGrad:
CHECK_EQ(new_operands.size(), 5);
return CreateBatchNormGrad(shape, new_operands[0], new_operands[1],

View File

@ -638,6 +638,27 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
metadata, fusion->fused_expression_root()->operand(0)->metadata()));
}
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
auto outfeed10 = builder.AddInstruction(
HloInstruction::CreateOutfeed(shape10, constant, ""));
auto outfeed01 = builder.AddInstruction(
HloInstruction::CreateOutfeed(shape01, constant, ""));
auto clone01 = builder.AddInstruction(outfeed01->Clone());
auto clone10 = builder.AddInstruction(outfeed10->Clone());
EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01));
EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10));
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.

View File

@ -92,4 +92,18 @@ HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
return options;
}
bool ReducePrecisionInsertion::AddPasses(
HloPassPipeline* pipeline, const DebugOptions& debug_options,
const HloReducePrecisionOptions::PassTiming pass_timing) {
bool passes_added = false;
for (const auto& pass_options :
debug_options.hlo_reduce_precision_options()) {
if (pass_options.pass_timing() == pass_timing) {
pipeline->AddPass<ReducePrecisionInsertion>(pass_options);
passes_added = true;
}
}
return passes_added;
}
} // namespace xla

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@ -72,6 +73,13 @@ class ReducePrecisionInsertion : public HloPassInterface {
const int exponent_bits, const int mantissa_bits,
const OpcodeFilterFunction& should_reduce_output_precision);
// Add ReducePrecisionInsertion passes to an HloPassPipeline based on the list
// of HloReducePrecisionOptions in a DebugOptions proto. Returns true if any
// passes were added.
static bool AddPasses(
HloPassPipeline* pipeline, const DebugOptions& debug_options,
const HloReducePrecisionOptions::PassTiming pass_timing);
private:
// Parameters for the precision reduction to be added.
const int exponent_bits_;

View File

@ -76,6 +76,7 @@ cuda_py_test(
"//tensorflow/python:nn_ops",
"//tensorflow/python:session",
],
tags = ["no_pip"], # contrib/learn:head_test is not available in pip.
)
cuda_py_test(

View File

@ -66,8 +66,10 @@ void DumpTraceToLogDirectory(const tensorflow::string& logdir,
LOG(INFO) << "Dumped trace data to " << path;
}
ProfileResponse Profile(const tensorflow::string& service_addr) {
ProfileResponse Profile(const tensorflow::string& service_addr,
int duration_ms) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
ProfileResponse response;
ClientContext context;
::grpc::ChannelArguments channel_args;
@ -77,7 +79,7 @@ ProfileResponse Profile(const tensorflow::string& service_addr) {
std::unique_ptr<TPUProfiler::Stub> stub =
TPUProfiler::NewStub(::grpc::CreateCustomChannel(
service_addr, ::grpc::InsecureChannelCredentials(), channel_args));
TF_CHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
return response;
}
@ -88,12 +90,16 @@ ProfileResponse Profile(const tensorflow::string& service_addr) {
int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
int FLAGS_duration_ms = 2000;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
"Address of TPU profiler service e.g. localhost:8466"),
tensorflow::Flag("logdir", &FLAGS_logdir,
"Path of TensorBoard log directory e.g. /tmp/tb_log"),
tensorflow::Flag("duration_ms", &FLAGS_duration_ms,
"Duration of tracing in ms. Default is 2000ms."),
};
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) {
@ -101,9 +107,23 @@ int main(int argc, char** argv) {
return 2;
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
int duration_ms = FLAGS_duration_ms;
tensorflow::ProfileResponse response =
tensorflow::tpu::Profile(FLAGS_service_addr);
tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms);
// Ignore computation_graph for now.
tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir,
response.encoded_trace());
if (response.encoded_trace().empty()) {
LOG(WARNING) << "No trace event is collected during the " << duration_ms
<< "ms interval.";
} else {
tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir,
response.encoded_trace());
}
// Print this at the end so that it's not buried in irrelevant LOG messages.
std::cout
<< "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
<< "Set an appropriate duration (with --duration_ms) if you "
"don't see a full step in your trace or the captured trace is too "
"large."
<< std::endl;
}

View File

@ -13,7 +13,8 @@ service TPUProfiler {
message ProfileRequest {
// In future, the caller will be able to customize when profiling starts and
// stops. For now, it always collects 10 seconds worth of data.
// stops. For now, it collects `duration_ms` milliseconds worth of data.
uint64 duration_ms = 1;
// In future, the caller will indicate which TF session is being profiled, and
// only data relating to that program will be returned. For now, we assume

View File

@ -55,12 +55,16 @@ void Cluster::DisableOptimizer(bool disable) {
options_.config.mutable_graph_options()->mutable_rewrite_options();
rewriter_config->set_optimize_tensor_layout(false);
rewriter_config->set_disable_model_pruning(true);
rewriter_config->set_constant_folding(false);
rewriter_config->set_constant_folding(RewriterConfig::OFF);
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
options->set_opt_level(OptimizerOptions::L1);
auto rewriter_config =
options_.config.mutable_graph_options()->mutable_rewrite_options();
rewriter_config->set_constant_folding(RewriterConfig::DEFAULT);
rewriter_config->set_memory_optimization(RewriterConfig::DEFAULT_MEM_OPT);
}
}

View File

@ -59,7 +59,7 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
class DeviceSimple : public DeviceBase {
public:
DeviceSimple() : DeviceBase(Env::Default()) {
eigen_worker_threads_.num_threads = 1;
eigen_worker_threads_.num_threads = port::NumSchedulableCPUs();
eigen_worker_threads_.workers = new thread::ThreadPool(
Env::Default(), "constant_folding", eigen_worker_threads_.num_threads);
eigen_threadpool_wrapper_.reset(
@ -101,6 +101,8 @@ string AsControlDependency(const NodeDef& node) {
} // namespace
ConstantFolding::ConstantFolding() {
resource_mgr_.reset(new ResourceMgr());
ops_to_preserve_ = std::regex(
"Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|"
"Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|"
@ -346,6 +348,7 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = op_kernel.get();
params.resource_manager = resource_mgr_.get();
gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
const int num_outputs = op_kernel->num_outputs();

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <regex>
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
@ -70,6 +71,7 @@ class ConstantFolding : public GraphOptimizer {
Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
std::unique_ptr<DeviceBase> device_;
std::unique_ptr<ResourceMgr> resource_mgr_;
GraphDef graph_;
std::unique_ptr<NodeMap> node_map_;
std::set<string> nodes_to_preserve_;

View File

@ -238,6 +238,7 @@ class NodeProcessor {
*node->add_input() = input_name;
*node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
node->set_op("Transpose");
node->set_device(node_->device());
AttrValue attr_data_type;
attr_data_type.set_type(data_type);
node->mutable_attr()->insert({"T", attr_data_type});
@ -273,11 +274,10 @@ class NodeProcessor {
int output_pos = NodePosition(node_->input(pos));
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
NodeDef* transpose = AddNodeTranspose(
AddNodeTranspose(
node_name, node_->input(pos), node_->attr().at("T").type(),
input_node->attr().at("_output_shapes").list().shape(output_pos),
true);
transpose->set_device(node_->device());
node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
node_map_->AddOutput(node_name, node_->name());
*node_->mutable_input(pos) = node_name;
@ -313,10 +313,9 @@ class NodeProcessor {
}
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
NodeDef* transpose = AddNodeTranspose(
node_name, node_->name(), node_->attr().at("T").type(),
node_->attr().at("_output_shapes").list().shape(0), false);
transpose->set_device(node_->device());
AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(),
node_->attr().at("_output_shapes").list().shape(0),
false);
*it = node_name;
node_map_->UpdateOutput(node_->name(), output->name(), node_name);
node_map_->AddOutput(node_name, output->name());
@ -604,6 +603,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
node_map_->AddNode(name, node);
node->set_name(name);
node->set_op("Const");
node->set_device(node_->device());
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type});
@ -628,6 +628,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
*node->add_input() = input_name;
*node->add_input() = shape_const_node_name;
node->set_op("Reshape");
node->set_device(node_->device());
AttrValue attr_type_indices;
attr_type_indices.set_type(DT_INT32);
@ -650,13 +651,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
int vector_size =
input_node->attr().at("_output_shapes").list().shape(0).dim(0).size();
NodeDef* shp = AddNodeShapeConst(shape_const_node_name, vector_size);
shp->set_device("/job:localhost/replica:0/task:0/cpu:0");
AddNodeShapeConst(shape_const_node_name, vector_size);
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
NodeDef* reshape =
AddNodeReshape(reshape_node_name, node_->input(1),
shape_const_node_name, node_->attr().at("T").type());
reshape->set_device(node_->device());
AddNodeReshape(reshape_node_name, node_->input(1), shape_const_node_name,
node_->attr().at("T").type());
node_map_->AddOutput(shape_const_node_name, reshape_node_name);
node_map_->UpdateOutput(node_->input(1), node_->name(),
reshape_node_name);
@ -953,8 +951,12 @@ struct TuningConfig {
class DataLayoutOptimizer {
public:
explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config)
: graph_(graph), node_map_(graph_), config_(config) {}
explicit DataLayoutOptimizer(const string& default_device, GraphDef* graph,
TuningConfig config)
: default_device_(default_device),
graph_(graph),
node_map_(graph_),
config_(config) {}
Status Optimize() {
LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
@ -972,6 +974,7 @@ class DataLayoutOptimizer {
node_map_.AddNode(name, node);
node->set_name(name);
node->set_op("Const");
node->set_device(default_device_);
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type});
@ -990,6 +993,7 @@ class DataLayoutOptimizer {
node_map_.AddNode(name, node);
node->set_name(name);
node->set_op("Const");
node->set_device(default_device_);
AttrValue attr_data_type;
attr_data_type.set_type(dtype);
node->mutable_attr()->insert({"dtype", attr_data_type});
@ -1014,6 +1018,7 @@ class DataLayoutOptimizer {
node_map_.AddNode(kReductionConst, node);
node->set_name(kReductionConst);
node->set_op("Const");
node->set_device(default_device_);
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type});
@ -1072,15 +1077,10 @@ class DataLayoutOptimizer {
// expanded.
if (graph_->node_size() > node_size_original) {
NodeDef* n = AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddNodeConcatConst();
n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddGatherAxisConst();
n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddNodeReductionConst();
n->set_device("/job:localhost/replica:0/task:0/cpu:0");
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
for (int i = 0; i < graph_->node_size(); i++) {
if (ops_format_agnostic.find(graph_->node(i).op()) !=
@ -1169,6 +1169,7 @@ class DataLayoutOptimizer {
return Status::OK();
}
string default_device_;
GraphDef* graph_;
NodeMap node_map_;
TuningConfig config_;
@ -1221,16 +1222,24 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*output = new_item.graph;
TuningConfig config;
config.no_gemm = false;
DataLayoutOptimizer layout_optimizer(output, config);
status = layout_optimizer.Optimize();
string default_device = "/job:localhost/replica:0/task:0/cpu:0";
if (cluster) {
if (!cluster->GetDevices().empty()) {
default_device = cluster->GetDevices().begin()->first;
}
}
std::unique_ptr<DataLayoutOptimizer> layout_optimizer(
new DataLayoutOptimizer(default_device, output, config));
status = layout_optimizer->Optimize();
// This is based on an empirical observation that if the introduced Transpose
// nodes is more than 30, not using GEMM implementation would result in better
// performance.
if (status.ok() && GetNumTranspose(*output) > 30) {
*output = new_item.graph;
config.no_gemm = true;
DataLayoutOptimizer layout_optimizer(output, config);
status = layout_optimizer.Optimize();
layout_optimizer.reset(
new DataLayoutOptimizer(default_device, output, config));
status = layout_optimizer->Optimize();
}
if (!status.ok()) {

View File

@ -447,7 +447,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
(cheap_to_recompute_ops.count(node.op()) > 0 ||
node.attr().count(kRecomputeHint) > 0);
});
} else { // optimization_level == RewriterConfig::MANUAL
} else if (optimization_level == RewriterConfig::MANUAL) {
recomputed_subgraphs =
GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&

View File

@ -58,7 +58,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!cfg_.disable_model_pruning()) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
}
if (cfg_.constant_folding()) {
if (cfg_.constant_folding() == RewriterConfig::ON) {
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
}
@ -66,7 +66,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
if (cfg_.memory_optimization() > 0) {
if (cfg_.memory_optimization() > 1) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new MemoryOptimizer(cfg_.memory_optimization())));
}
@ -121,8 +121,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() ||
cfg.constant_folding() || cfg.auto_parallel().enable() ||
cfg.memory_optimization() > 0 || !cfg.optimizers().empty();
cfg.constant_folding() == RewriterConfig::ON ||
cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 ||
!cfg.optimizers().empty();
}
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,

View File

@ -272,87 +272,6 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
}
}
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
// when only one of the dimension sizes is smaller than 16,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// small_dim = the_smaller_dimension_size
// large_dim = the_larger_dimension_size
// tile_num_per_block = blockDim.x
// kTileLength = small_dim
//
// Each thread block operates on a single rectangle tile, where its width is
// kTileLength (we currently set it to 64) and its height is small_dim,
// We set the thread block's X dimension to be tile_num_per_block, and its Y
// and Z to be one.
template <typename T, bool SmallDim2>
__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
int batch_per_block,
Dimension<3> input_dims,
T* output) {
// TODO(yangzihao) avoid share memory bank conflict.
extern __shared__ __align__(sizeof(T)) unsigned char shmem[];
T* shared_memory_tile = reinterpret_cast<T*>(shmem);
eigen_assert(blockDim.y == 1);
eigen_assert(blockDim.z == 1);
eigen_assert(gridDim.z == 1);
int block_offset = blockIdx.x * blockDim.x * batch_per_block;
int x = threadIdx.x;
int tile_height = blockDim.x;
// Get tile height, width, and thread/block origin indices.
int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];
int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
(SmallDim2 ? block_offset * small_dim : block_offset) + x;
if (global_offset > (input_dims[0] * input_dims[1] * input_dims[2])) return;
for (int batch = 0; batch < batch_per_block; ++batch) {
int block_origin_idx =
small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
int thread_origin_idx =
block_origin_idx +
(SmallDim2 ? block_offset * small_dim : block_offset) + x;
if (block_offset + blockDim.x > large_dim) {
tile_height = large_dim - block_offset;
}
// Load a continuous memory region to shared memory tile.
if (x < tile_height) {
for (int y = 0; y < small_dim; y++) {
int shmem_index =
SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
shared_memory_tile[shmem_index] =
ldg(input + thread_origin_idx +
y * (SmallDim2 ? tile_height : large_dim));
}
}
__syncthreads();
// Get block origin index for output array.
int output_block_offset = block_origin_idx;
int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
int output_block_origin_idx = output_block_offset + output_block_idx;
// Store the tranposed memory region in shared memory to device.
if (x < tile_height) {
for (int y = 0; y < small_dim; y++) {
int output_idx = output_block_origin_idx + x +
y * (SmallDim2 ? large_dim : tile_height);
int shmem_index =
SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
output[output_idx] = shared_memory_tile[shmem_index];
}
}
}
}
// A Cuda custom kernel that convert input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
@ -501,63 +420,25 @@ template <typename T>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
const Dimension<3>& input_dims, T* output) {
// If both dimensions are not trivial, use tiles for the actual swapping.
// If one dimension is trivial, use SmallDim kernel for swapping.
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
static const int kMinDimensionToUseTiles = 16;
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles);
bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] < kMinDimensionToUseTiles)) ||
((input_dims[1] < kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles));
static const int NumSubTiles = 8;
if (use_tiles) {
// We get best performance when TileSize is the number of threads in a warp
// (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
// threads.
static const int TileSize = 32;
static const int NumSubTiles = 8;
Dimension<3> input_dims_in_tiles = {
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
(input_dims[2] + TileSize - 1) / TileSize,
};
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
// We get best performance when TileSize is the number of threads in a warp
// (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
// threads.
SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
input, input_dims, output);
} else if (use_small_dim) {
// When only one of the dimensions is smaller than kMinDimensionToUseTiles,
// we use one block to process a rectangle region with the size of
// kTileLength * small_dim. We found that when set kTileLength to 64 on
// TitanX Maxwell GPU, it achieves the best performance.
// large_dim
// +---------------...--------+
// | | | |
// small_dim | | ... | |
// | | | |
// +--------------...---------+
// \----- ------/ \- -/
// V V
// kTileLength(tile_height) tile_height
static const int kTileLength = 64;
static const int kGridDimY = 2048;
int small_dim = std::min(input_dims[2], input_dims[1]);
int large_dim = std::max(input_dims[2], input_dims[1]);
int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
int grid_dim_y = input_dims[0] < kGridDimY ? input_dims[0] : kGridDimY;
int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
if (input_dims[2] < input_dims[1]) {
SwapDimension1And2InTensor3SmallDim<T, true>
<<<dim3(tile_num_per_block, grid_dim_y), kTileLength,
kTileLength * small_dim * sizeof(T), d.stream()>>>(
input, batch_per_block, input_dims, output);
} else {
SwapDimension1And2InTensor3SmallDim<T, false>
<<<dim3(tile_num_per_block, grid_dim_y), kTileLength,
kTileLength * small_dim * sizeof(T), d.stream()>>>(
input, batch_per_block, input_dims, output);
}
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);

View File

@ -437,7 +437,7 @@ TEST(GraphTransferer,
<< "with quantized input";
CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_QUINT8, {1, WIDTH, HEIGHT, DEPTH}));

View File

@ -77,6 +77,13 @@ struct hash<StringPiece> {
}
};
template <typename T, typename U>
struct hash<std::pair<T, U>> {
size_t operator()(const std::pair<T, U>& p) const {
return Hash64Combine(hash<T>()(p.first), hash<U>()(p.second));
}
};
} // namespace tensorflow
#endif // TENSORFLOW_LIB_HASH_HASH_H_

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sys/stat.h>
#include <deque>
#include <utility>
#include <vector>
@ -30,7 +31,10 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@ -273,6 +277,39 @@ string Env::GetExecutablePath() {
return exe_path;
}
bool Env::LocalTempFilename(string* filename) {
std::vector<string> dirs;
GetLocalTempDirectories(&dirs);
// Try each directory, as they might be full, have inappropriate
// permissions or have different problems at times.
for (const string& dir : dirs) {
#ifdef __APPLE__
uint64_t tid64;
pthread_threadid_np(nullptr, &tid64);
int32 tid = static_cast<int32>(tid64);
int32 pid = static_cast<int32>(getpid());
#elif defined(PLATFORM_WINDOWS)
int32 tid = static_cast<int32>(GetCurrentThreadId());
int32 pid = static_cast<int32>(GetCurrentProcessId());
#else
int32 tid = static_cast<int32>(pthread_self());
int32 pid = static_cast<int32>(getpid());
#endif
uint64 now_microsec = NowMicros();
*filename = io::JoinPath(
dir, strings::Printf("tempfile-%s-%x-%d-%llx", port::Hostname().c_str(),
tid, pid, now_microsec));
if (FileExists(*filename).ok()) {
filename->clear();
} else {
return true;
}
}
return false;
}
Thread::~Thread() {}
EnvWrapper::~EnvWrapper() {}

View File

@ -215,6 +215,9 @@ class Env {
/// symlinks if there is any.
string GetExecutablePath();
/// Creates a local unique temporary file name. Returns true if success.
bool LocalTempFilename(string* filename);
// TODO(jeff,sanjay): Add back thread/thread-pool support if needed.
// TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
// provide a routine to get the absolute time.
@ -279,6 +282,9 @@ class Env {
const string& version) = 0;
private:
// Returns a possible list of local temporary directories.
void GetLocalTempDirectories(std::vector<string>* list);
std::unique_ptr<FileSystemRegistry> file_system_registry_;
TF_DISALLOW_COPY_AND_ASSIGN(Env);
EnvTime* envTime = EnvTime::Default();

View File

@ -298,4 +298,32 @@ TEST_F(DefaultEnvTest, GetExecutablePath) {
TF_EXPECT_OK(env->FileExists(env->GetExecutablePath()));
}
TEST_F(DefaultEnvTest, LocalTempFilename) {
Env* env = Env::Default();
string filename;
EXPECT_TRUE(env->LocalTempFilename(&filename));
EXPECT_FALSE(env->FileExists(filename).ok());
// Write something to the temporary file.
std::unique_ptr<WritableFile> file_to_write;
TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write));
TF_CHECK_OK(file_to_write->Append("Null"));
TF_CHECK_OK(file_to_write->Close());
TF_CHECK_OK(env->FileExists(filename));
// Read from the temporary file and check content.
std::unique_ptr<RandomAccessFile> file_to_read;
TF_CHECK_OK(env->NewRandomAccessFile(filename, &file_to_read));
StringPiece content;
char scratch[1024];
CHECK_EQ(error::OUT_OF_RANGE,
file_to_read->Read(0 /* offset */, 1024 /* n */, &content, scratch)
.code());
EXPECT_EQ("Null", content.ToString());
// Delete the temporary file.
TF_CHECK_OK(env->DeleteFile(filename));
EXPECT_FALSE(env->FileExists(filename).ok());
}
} // namespace tensorflow

View File

@ -131,4 +131,39 @@ Env* Env::Default() {
}
#endif
void Env::GetLocalTempDirectories(std::vector<string>* list) {
list->clear();
// Directories, in order of preference. If we find a dir that
// exists, we stop adding other less-preferred dirs
const char* candidates[] = {
// Non-null only during unittest/regtest
getenv("TEST_TMPDIR"),
// Explicitly-supplied temp dirs
getenv("TMPDIR"),
getenv("TMP"),
// If all else fails
"/tmp",
};
for (const char* d : candidates) {
if (!d || d[0] == '\0') continue; // Empty env var
// Make sure we don't surprise anyone who's expecting a '/'
string dstr = d;
if (dstr[dstr.size() - 1] != '/') {
dstr += "/";
}
struct stat statbuf;
if (!stat(d, &statbuf) && S_ISDIR(statbuf.st_mode) &&
!access(dstr.c_str(), 0)) {
// We found a dir that exists and is accessible - we're done.
list->push_back(dstr);
return;
}
}
}
} // namespace tensorflow

View File

@ -172,4 +172,21 @@ Env* Env::Default() {
return default_env;
}
void Env::GetLocalTempDirectories(std::vector<string>* list) {
list->clear();
// On windows we'll try to find a directory in this order:
// C:/Documents & Settings/whomever/TEMP (or whatever GetTempPath() is)
// C:/TMP/
// C:/TEMP/
// C:/WINDOWS/ or C:/WINNT/
// .
char tmp[MAX_PATH];
// GetTempPath can fail with either 0 or with a space requirement > bufsize.
// See http://msdn.microsoft.com/en-us/library/aa364992(v=vs.85).aspx
DWORD n = GetTempPathA(MAX_PATH, tmp);
if (n > 0 && n <= MAX_PATH) list->push_back(tmp);
list->push_back("C:\\tmp\\");
list->push_back("C:\\temp\\");
}
} // namespace tensorflow

View File

@ -19,18 +19,30 @@ message RewriterConfig {
// configuration options do not apply to explicitly triggered optimization
// passes in the optimizers field.
enum Toggle {
DEFAULT = 0;
ON = 1;
OFF = 2;
}
// Optimize tensor layouts
bool optimize_tensor_layout = 1;
// Fold constants (default is OFF)
Toggle constant_folding = 3;
// If true, don't remove unecessary ops from the graph
bool disable_model_pruning = 2;
bool constant_folding = 3;
enum MemOptType {
// The default setting (currently disabled)
DEFAULT_MEM_OPT = 0;
// Disabled in the meta-optimizer.
NO_MEM_OPT = 0;
NO_MEM_OPT = 1;
// Driven by manual op-level annotations.
MANUAL = 1;
MANUAL = 2;
// Driven by heuristics. The behavior of these heuristics is subject to
// change. Currently includes an experimental recomputation heuristic.
HEURISTICS = 2;
HEURISTICS = 3;
}
// Configures memory optimization passes through the meta-optimizer. Has no
// effect on manually requested memory optimization passes in the optimizers

View File

@ -186,9 +186,8 @@ class Session {
/// the `SessionOptions::target` field).
virtual Status Close() = 0;
// NOTE(ashankar): As of July 2017, this is was a method added to
// faciliate some experimentation. Reconsider/re-evaluate after
// September 2017.
// NOTE(ashankar): As of July 2017, this method was added to faciliate some
// experimentation. Reconsider/re-evaluate after September 2017.
//
// Sets `*output` to the `DeviceMgr` that owns accessible devices in the
// address-space of the caller.

View File

@ -3799,26 +3799,6 @@ cuda_py_test(
main = "ops/concat_benchmark.py",
)
cuda_py_test(
name = "conv2d_benchmark",
size = "large",
srcs = ["ops/conv2d_benchmark.py"],
additional_deps = [
":client",
":client_testlib",
":control_flow_ops",
":framework_for_generated_wrappers",
":nn_ops",
":platform",
":platform_benchmark",
":random_ops",
":variables",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
],
main = "ops/conv2d_benchmark.py",
)
cuda_py_test(
name = "split_benchmark",
srcs = ["ops/split_benchmark.py"],

View File

@ -229,26 +229,6 @@ class TransposeTest(test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
def testLargeSizeGPU(self):
# If no GPU available, skip the test
if not test.is_gpu_available(cuda_only=True):
return
large_shapes = [[1000000, 31, 3], [3, 1000000, 31], [3, 31, 1000000],
[2, 1000, 1000], [1000, 2, 1000], [1000, 1000, 2]]
perms = [[0, 2, 1]] * 6
for input_shape, perm in zip(large_shapes, perms):
total_size = np.prod(input_shape)
inp = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = y.eval()
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
def testNop(self):
self._compareCpu(np.arange(0, 6).reshape([3, 2]).astype(np.float32), [0, 1])

View File

@ -1,141 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Conv2D op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import time
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def build_graph(device, input_shape, filter_shape, strides, padding, num_iters):
"""builds a graph containing a sequence of conv2d operations.
Args:
device: String, the device to run on.
input_shape: Shape of the input tensor.
filter_shape: Shape of the filter tensor.
strides: A list of ints. 1-D of length 4. The stride of sliding
window for each dimension of input.
padding: A string from: "SAME", "VALID". The type of padding
algorithm to use.
num_iters: number of iterations to run conv2d.
Returns:
An array of tensors to run()
"""
with ops.device("/%s:0" % device):
inp = variables.Variable(random_ops.truncated_normal(input_shape))
filt = variables.Variable(random_ops.truncated_normal(filter_shape))
outputs = []
conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC")
outputs.append(conv2d_op)
for _ in range(1, num_iters):
with ops.control_dependencies([conv2d_op]):
conv2d_op = nn_ops.conv2d(
inp, filt, strides, padding, data_format="NHWC")
outputs.append(conv2d_op)
return control_flow_ops.group(*outputs)
class Conv2DBenchmark(test.Benchmark):
"""Benchmark conv2d!"""
def _run_graph(self, device, input_shape, filter_shape, strides, padding,
num_iters):
"""runs the graph and print its execution time.
Args:
device: String, the device to run on.
input_shape: Shape of the input tensor.
filter_shape: Shape of the filter tensor.
strides: A list of ints. 1-D of length 4. The stride of sliding
window for each dimension of input.
padding: A string from: "SAME", "VALID". The type of padding
algorithm to use. num_iters: Number of iterations to run the
benchmark.
num_iters: number of iterations to run conv2d.
Returns:
The duration of the run in seconds.
"""
graph = ops.Graph()
with graph.as_default():
outputs = build_graph(device, input_shape, filter_shape, strides, padding,
num_iters)
with session_lib.Session(graph=graph) as session:
variables.global_variables_initializer().run()
# warmup runs
session.run(outputs)
start_time = time.time()
session.run(outputs)
duration = (time.time() - start_time) / num_iters
print("%s inputshape:%s filtershape:%s strides:%s padding:%s "
"%d iters: %.8f sec" %
(device, str(input_shape).replace(" ", ""),
str(filter_shape).replace(" ", ""),
str(strides).replace(" ", ""), padding, num_iters, duration))
name_template = (
"conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_"
"strides_{strides}_padding_{padding}")
self.report_benchmark(
name=name_template.format(
device=device,
inputshape=str(input_shape).replace(" ", ""),
filtershape=str(filter_shape).replace(" ", ""),
strides=str(strides).replace(" ", ""),
padding=padding).replace(" ", ""),
iters=num_iters,
wall_time=duration / num_iters)
return duration
def benchmark_conv2d(self):
print("conv2d benchmark:")
h = 500
w = 500
fh = 3
fw = 3
input_shapes = []
filter_shapes = []
for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]):
input_shapes += [[b, h, w, c]]
filter_shapes += [[fh, fw, c, b]]
strides = [[1, 2, 2, 1]]
paddings = ["VALID", "SAME"]
for ishape, fshape in zip(input_shapes, filter_shapes):
for stride in strides:
for padding in paddings:
self._run_graph("gpu", ishape, fshape, stride, padding, 80)
if __name__ == "__main__":
test.main()

View File

@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -32,7 +32,7 @@ from tensorflow.python.platform import test
def build_graph(device, input_shape, perm, datatype, num_iters):
"""builds a graph containing a sequence of conv2d operations.
"""Build a graph containing a sequence of conv2d operations.
Args:
device: String, the device to run on.
@ -50,12 +50,10 @@ def build_graph(device, input_shape, perm, datatype, num_iters):
t = constant_op.constant(inp, shape=input_shape)
outputs = []
transpose_op = array_ops.transpose(t, perm)
outputs.append(transpose_op)
for _ in range(1, num_iters):
with ops.control_dependencies([transpose_op]):
transpose_op = array_ops.transpose(t, perm)
outputs.append(transpose_op)
outputs.append(array_ops.transpose(t, perm))
for i in range(1, num_iters):
with ops.control_dependencies([outputs[i - 1]]):
outputs.append(array_ops.transpose(t, perm))
return control_flow_ops.group(*outputs)
@ -63,7 +61,7 @@ class TransposeBenchmark(test.Benchmark):
"""Benchmark transpose!"""
def _run_graph(self, device, input_shape, perm, num_iters, datatype):
"""runs the graph and print its execution time.
"""Run the graph and print its execution time.
Args:
device: String, the device to run on.
@ -84,11 +82,9 @@ class TransposeBenchmark(test.Benchmark):
session.run(outputs)
start_time = time.time()
session.run(outputs)
duration = (time.time() - start_time) / num_iters
throughput = np.prod(
np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9
print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." %
(device, str(datatype), str(input_shape).replace(" ", ""),
str(perm).replace(" ", ""), num_iters, duration, throughput))
@ -112,19 +108,19 @@ class TransposeBenchmark(test.Benchmark):
datatypes = [np.complex128, np.float64, np.float32, np.float16, np.int8]
small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2
small_shapes += [[2, 100, 100, 16], [2, 16, 100, 100]] * 2
small_shapes += [[2, 5000, 16], [2, 16, 5000]] * 2
small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
small_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
small_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + [[
2, 100, 100, 16
], [2, 16, 100, 100]] * 2 + [[2, 5000, 16], [2, 16, 5000]] * 2
small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
[0, 3, 1, 2], [0, 2, 3, 1]
] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
large_shapes = [[2, 100, 100, 100, 32], [2, 100, 100, 100, 64]] * 2
large_shapes += [[2, 1000, 1000, 32], [2, 1000, 1000, 64]] * 2
large_shapes += [[2, 1000000, 32], [2, 1000000, 64]] * 2
large_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
large_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
large_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
large_shapes = [[2, 100, 100, 100, 32], [2, 100, 100, 100, 64]] * 2 + [[
2, 1000, 1000, 32
], [2, 1000, 1000, 64]] * 2 + [[2, 1000000, 32], [2, 1000000, 64]] * 2
large_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
[0, 3, 1, 2], [0, 2, 3, 1]
] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
huge_shapes = [[2, 100, 100, 100, 128], [2, 1000, 1000, 128],
[2, 1000000, 128]] * 2
@ -147,23 +143,5 @@ class TransposeBenchmark(test.Benchmark):
for ishape, perm in zip(huge_shapes, huge_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
small_dim_large_shapes = [[2, 1000000, 3], [2, 3, 1000000], [2, 1000000, 8],
[2, 8, 1000000]]
small_dim_small_shapes = [[2, 5000, 3], [2, 3, 5000], [2, 5000, 8],
[2, 8, 5000]]
small_dim_perms = [[0, 2, 1]] * 4
num_iters = 320
small_dim_large_shape_datatypes = [np.float64, np.float32, np.int8]
for datatype in small_dim_large_shape_datatypes:
for ishape, perm in zip(small_dim_large_shapes, small_dim_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
small_dim_small_shape_datatypes = [np.complex128, np.float16]
for datatype in small_dim_small_shape_datatypes:
for ishape, perm in zip(small_dim_small_shapes, small_dim_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
if __name__ == "__main__":
test.main()

View File

@ -492,8 +492,15 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
out_tensor.set_shape([None]) # necessary when t.ndims is unknown
return out_tensor
def _sparse_values_to_keep(t, keep_input):
"""Convert a per-row `keep_input` vector to a per-value one."""
# Get the rows of every value in the sparse Tensor.
row_values = array_ops.reshape(
t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0]
# The value should be kept iff the row should be kept.
return array_ops.gather(keep_input, row_values)
if keep_input.shape.ndims == 1:
t = sparse_ops.sparse_retain(t, keep_input)
t = sparse_ops.sparse_retain(t, _sparse_values_to_keep(t, keep_input))
store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name)
elif enqueue_many:
store_f = _maybe_store_many_sparse
@ -877,7 +884,7 @@ def batch(tensors, batch_size, num_threads=1, capacity=32,
`batch_size` is returned when the queue is closed and there are not enough
elements to fill the batch, otherwise the pending elements are discarded.
In addition, all output tensors' static shapes, as accessed via the
`get_shape` method will have a first `Dimension` value of `None`, and
`shape` property will have a first `Dimension` value of `None`, and
operations that depend on fixed batch_size would fail.
Args:
@ -1031,7 +1038,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
`batch_size` is returned when the queue is closed and there are not enough
elements to fill the batch, otherwise the pending elements are discarded.
In addition, all output tensors' static shapes, as accessed via the
`get_shape` method will have a first `Dimension` value of `None`, and
`shape` property will have a first `Dimension` value of `None`, and
operations that depend on fixed batch_size would fail.
Args:
@ -1086,8 +1093,8 @@ def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
added to the queue or not. If it is a scalar and evaluates `True`, then
`tensors` are all added to the queue. If it is a vector and `enqueue_many`
is `True`, then each example is added to the queue only if the
corresponding value in `keep_input` is `True`. This tensor essentially acts
as a filtering mechanism.
corresponding value in `keep_input` is `True`. This tensor essentially
acts as a filtering mechanism.
batch_size: An integer. The new batch size pulled from the queue.
capacity: An integer. The maximum number of elements in the queue.
enqueue_many: Whether each tensor in `tensor_list_list` is a single
@ -1176,7 +1183,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
`batch_size` is returned when the queue is closed and there are not enough
elements to fill the batch, otherwise the pending elements are discarded.
In addition, all output tensors' static shapes, as accessed via the
`get_shape` method will have a first `Dimension` value of `None`, and
`shape` property will have a first `Dimension` value of `None`, and
operations that depend on fixed batch_size would fail.
Args:
@ -1237,8 +1244,8 @@ def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
added to the queue or not. If it is a scalar and evaluates `True`, then
`tensors` are all added to the queue. If it is a vector and `enqueue_many`
is `True`, then each example is added to the queue only if the
corresponding value in `keep_input` is `True`. This tensor essentially acts
as a filtering mechanism.
corresponding value in `keep_input` is `True`. This tensor essentially
acts as a filtering mechanism.
num_threads: The number of threads enqueuing `tensor_list`.
seed: Seed for the random shuffling within the queue.
enqueue_many: Whether each tensor in `tensor_list` is a single example.
@ -1318,7 +1325,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,
`batch_size` is returned when the queue is closed and there are not enough
elements to fill the batch, otherwise the pending elements are discarded.
In addition, all output tensors' static shapes, as accessed via the
`get_shape` method will have a first `Dimension` value of `None`, and
`shape` property will have a first `Dimension` value of `None`, and
operations that depend on fixed batch_size would fail.
Args:
@ -1379,8 +1386,8 @@ def maybe_shuffle_batch_join(tensors_list, batch_size, capacity,
added to the queue or not. If it is a scalar and evaluates `True`, then
`tensors` are all added to the queue. If it is a vector and `enqueue_many`
is `True`, then each example is added to the queue only if the
corresponding value in `keep_input` is `True`. This tensor essentially acts
as a filtering mechanism.
corresponding value in `keep_input` is `True`. This tensor essentially
acts as a filtering mechanism.
seed: Seed for the random shuffling within the queue.
enqueue_many: Whether each tensor in `tensor_list_list` is a single
example.

View File

@ -903,6 +903,29 @@ class BatchTest(test_lib.TestCase):
[sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchCorrectValues(self):
sparse_t = sparse_tensor.SparseTensor(
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
dense_shape=[2, 4],
values=[5, 4, 7, 2])
keep = constant_op.constant([True, False])
batched = inp.maybe_batch(
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
with self.test_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
batched_np = batched.eval()
coord.request_stop()
for thread in threads:
thread.join()
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
self.assertAllEqual([5, 4], batched_np.values)
self.assertAllEqual([1, 4], batched_np.dense_shape)
class BatchJoinTest(test_lib.TestCase):
@ -1457,6 +1480,29 @@ class BatchJoinTest(test_lib.TestCase):
[[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchCorrectValues(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
dense_shape=[2, 4],
values=[5, 4, 7, 2])
keep = constant_op.constant([True, False])
batched = inp.maybe_batch_join(
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
with self.test_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
batched_np = batched.eval()
coord.request_stop()
for thread in threads:
thread.join()
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
self.assertAllEqual([5, 4], batched_np.values)
self.assertAllEqual([1, 4], batched_np.dense_shape)
class ShuffleBatchTest(test_lib.TestCase):

View File

@ -51,8 +51,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
variable: A Variable.
value: A tensor with the same shape as 'variable'.
decay: A float Tensor or float value. The moving average decay.
zero_debias: A python bool. If true, assume the variable is 0-initialized and
unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
zero_debias: A python bool. If true, assume the variable is 0-initialized
and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
`_zero_debias` for more details.
name: Optional name of the returned operation.

View File

@ -507,7 +507,7 @@ def _parse_function_details(docstring):
pairs = list(_gen_pairs(parts[1:]))
function_details = []
item_re = re.compile(r'^ (\w+):', re.MULTILINE)
item_re = re.compile(r'^ (\*?\*?\w+):', re.MULTILINE)
for keyword, content in pairs:
content = item_re.split(content)