Cleanup: Ran clang-format on files in tensorflow/core/.../*.{cc,h}.
PiperOrigin-RevId: 183848459
parent 88eb6c61ef
commit 7149a2e2e2
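Every hunk below is mechanical reformatting: clang-format, run with the repository's Google C++ style (80-column limit), rewraps call expressions that overflow the limit, joins short bodies and fluent call chains onto one line, sorts #include blocks, and normalizes spacing (for example, "if(" becomes "if (" and "1<<30" becomes "1 << 30"). As an illustration only, using hypothetical stand-in names rather than TensorFlow APIs, this self-contained C++ sketch shows the wrapped shape clang-format produces for an over-long call:

// Hypothetical, self-contained sketch of the post-format shape; plain
// std::string concatenation stands in for strings::StrCat/StrAppend.
#include <iostream>
#include <string>

std::string DebugString(const std::string& size,
                        const std::string& requested_size, bool in_use) {
  // When an argument list would overflow 80 columns, clang-format breaks
  // right after the opening parenthesis and packs the arguments, rather
  // than aligning every argument under the first one.
  return "Size: " + size + " | Requested Size: " + requested_size +
         " | in_use: " + (in_use ? "1" : "0");
}

int main() {
  std::cout << DebugString("256.0KiB", "200.0KiB", true) << "\n";
  return 0;
}

The same preference explains the StrAppend and StrCat hunks below: once the packed arguments fit, breaking after the opening parenthesis wins over first-argument alignment.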
@@ -13,11 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/graph/node_builder.h"
 
 namespace tensorflow {
 namespace {
 
@@ -44,7 +42,6 @@ Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) {
 // third-party libraries aren't currently supported.
 class AccumulateNV2RemovePass : public GraphOptimizationPass {
  public:
 
   Status Run(const GraphOptimizationPassOptions& options) override {
     // TODO(freiss.oss@gmail.com): Substantial shared code with
     // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
@@ -127,10 +127,10 @@ class BFCAllocator : public VisitableAllocator {
     string DebugString(BFCAllocator* a,
                        bool recurse) NO_THREAD_SAFETY_ANALYSIS {
       string dbg;
-      strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size),
-                         " | Requested Size: ",
-                         strings::HumanReadableNumBytes(requested_size),
-                         " | in_use: ", in_use());
+      strings::StrAppend(
+          &dbg, " Size: ", strings::HumanReadableNumBytes(size),
+          " | Requested Size: ", strings::HumanReadableNumBytes(requested_size),
+          " | in_use: ", in_use());
       if (recurse && prev != BFCAllocator::kInvalidChunkHandle) {
         Chunk* p = a->ChunkFromHandle(prev);
         strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false));
@@ -88,7 +88,9 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
   // D3 is prioritized below D1.
   AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");
   EXPECT_EQ((std::vector<DeviceType>{
-                DeviceType("d2"), DeviceType("d1"), DeviceType("d3"),
+                DeviceType("d2"),
+                DeviceType("d1"),
+                DeviceType("d3"),
             }),
             types());
 }
@@ -61,7 +61,6 @@ limitations under the License.
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
 
 namespace {
@@ -472,9 +471,9 @@ Status DirectSession::Run(const RunOptions& run_options,
   Executor::Args args;
   args.step_id = step_id_counter_.fetch_add(1);
 
-  TF_RETURN_IF_ERROR(
-      GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
-                           &executors_and_keys, &run_state_args));
+  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
+                                          target_nodes, &executors_and_keys,
+                                          &run_state_args));
   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
 
   std::unique_ptr<DebuggerStateInterface> debugger_state;
@@ -436,10 +436,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) {
   }
 }
 
-REGISTER_OP("Darth")
-    .Input("x: float")
-    .Output("y: float")
-    .Doc(R"doc(
+REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
 Darth promises one return value.
 
 x: float
@@ -972,39 +969,38 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
 
   std::atomic<int32> num_done(0);
   // Runs session to compute <node>:0 using inter_op thread pool <pool>.
-  auto add_session_run_call = [use_global_pools, &def, &options, &sessions,
-                               &sessions_mu,
-                               &num_done](thread::ThreadPool* tp, Node* node,
-                                          int inter_op_pool) {
-    auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
-               inter_op_pool, node, &num_done]() {
-      RunOptions run_options;
-      run_options.set_inter_op_thread_pool(inter_op_pool);
-      std::vector<Tensor> outputs;
+  auto add_session_run_call =
+      [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done](
+          thread::ThreadPool* tp, Node* node, int inter_op_pool) {
+        auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
+                   inter_op_pool, node, &num_done]() {
+          RunOptions run_options;
+          run_options.set_inter_op_thread_pool(inter_op_pool);
+          std::vector<Tensor> outputs;
 
       Session* session;
       if (use_global_pools) {
         std::unique_ptr<Session> s(NewSession(options));
         TF_ASSERT_OK(s->Create(def));
         session = s.get();
 
         mutex_lock l(sessions_mu);
         sessions.emplace_back(std::move(s));
       } else {
         session = sessions[0].get();
       }
 
       Status s = session->Run(run_options, {} /* inputs */,
                               {node->name() + ":0"} /* output_names */, {},
                               &outputs, nullptr /* run_metadata */);
       TF_CHECK_OK(s);
       ASSERT_EQ(1, outputs.size());
       auto flat = outputs[0].flat<float>();
       EXPECT_FLOAT_EQ(1.2, flat(0));
       num_done.fetch_add(1);
     };
     tp->Schedule(fn);
   };
 
   // For blocking states:
   // - Starts at 0, BlockingOp::Compute will move to 1.
@@ -161,14 +161,14 @@ static void TestHWAccelerator(bool enableHWTrace) {
   x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
 #ifdef TENSORFLOW_USE_SYCL
   x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
 #endif  // TENSORFLOW_USE_SYCL
 
   // y = A * x
   Node* y = test::graph::Matmul(&graph, a, x, false, false);
   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
 #ifdef TENSORFLOW_USE_SYCL
   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
 #endif  // TENSORFLOW_USE_SYCL
 
   Node* y_neg = test::graph::Unary(&graph, "Neg", y);
   y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
@@ -181,7 +181,7 @@ y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
   (*options.config.mutable_device_count())["GPU"] = 1;
 #ifdef TENSORFLOW_USE_SYCL
   (*options.config.mutable_device_count())["SYCL"] = 1;
 #endif  // TENSORFLOW_USE_SYCL
   options.config.set_allow_soft_placement(true);
   options.config.mutable_graph_options()->set_build_cost_model(1);
   std::unique_ptr<Session> session(NewSession(options));
@@ -1609,7 +1609,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
     auto done = [this, state]() {
       Device* device = impl_->params_.device;
       NodeExecStatsWrapper* stats = state->stats;  // Shorthand
       Entry* first_input = state->first_input;     // Shorthand
 
       nodestats::SetOpEnd(stats);
       EntryVector outputs;
@@ -205,7 +205,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   // The instantiated and transformed function is encoded as a Graph
   // object, and an executor is created for the graph.
   struct Item : public core::RefCounted {
     const Graph* graph = nullptr;                            // Owned by exec.
     const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
     FunctionBody* func_graph = nullptr;
     Executor* exec = nullptr;
@@ -154,8 +154,9 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
     a.DeallocateRaw(t3);
     a.DeallocateRaw(t4);
   }
-  CheckStats(&a, 4097, 0, 1024 * sizeof(float) + 1048576 * sizeof(int64) +
-                              2048 * sizeof(double) + 10485760 * sizeof(float),
-             10485760 * sizeof(float));
+  CheckStats(&a, 4097, 0,
+             1024 * sizeof(float) + 1048576 * sizeof(int64) +
+                 2048 * sizeof(double) + 10485760 * sizeof(float),
+             10485760 * sizeof(float));
 
   // At the end, we should have coalesced all memory into one region
@@ -763,8 +763,9 @@ int64 MinSystemMemory(int64 available_memory) {
   min_system_memory *= 2;
 #endif
 #if defined(NVIDIA_TEGRA)
-  // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM
-  min_system_memory = 1<<30;
+  // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM
+  // and Video RAM
+  min_system_memory = 1 << 30;
 #endif
   return min_system_memory;
 }
@@ -108,7 +108,8 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) {
   ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
              "/device:GPU:0");
   Output n = ops::MatMul(root, {}, {});
-  ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0");
+  ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0,
+             "/cpu:0");
   Graph g(OpRegistry::Global());
   TF_ASSERT_OK(root.ToGraph(&g));
 
@@ -88,8 +88,8 @@ ProcessState::~ProcessState() {
 }
 
 string ProcessState::MemDesc::DebugString() {
-  return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ",
-                         gpu_registered, ", nic: ", nic_registered);
+  return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index,
+                         ", dma: ", gpu_registered, ", nic: ", nic_registered);
 }
 
 ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
@@ -139,9 +139,7 @@ class GraphExecutionState {
 
   // The graph returned by BuildGraph may contain only the pruned
   // graph, whereas some clients may want access to the full graph.
-  const Graph* full_graph() {
-    return graph_;
-  }
+  const Graph* full_graph() { return graph_; }
 
   // Returns the node with the given name, or null if it does not exist.
   const Node* get_node_by_name(const string& name) const {
@@ -47,7 +47,7 @@ struct EndpointEq {
 static Status ProcessMemoryTypes(
     const DeviceType& device_type, const Graph* g,
     const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
-  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) {
+  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) {
     // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
     // compatible.
     return Status::OK();
@@ -36,7 +36,7 @@ TEST(MemoryTypeChecker, Int32OK) {
 #endif  // GOOGLE_CUDA
 #ifdef TENSORFLOW_USE_SYCL
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
 #endif  // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -64,7 +64,7 @@ TEST(MemoryTypeChecker, Int32NotOk) {
   // But we can insert _HostSend/_HostRecv to ensure the invariant.
   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g));
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
 #endif  // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -91,7 +91,7 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) {
   TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type));
   // int Switch's output on GPU has HOST_MEMORY constraint.
   EXPECT_EQ(memory_type, HOST_MEMORY);
 #endif  // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -88,9 +88,9 @@ class Placer {
   void AssignAndLog(int assigned_device, Node* node) const;
   void LogDeviceAssignment(const Node* node) const;
 
   Graph* const graph_;              // Not owned.
   const DeviceSet* const devices_;  // Not owned.
   const SessionOptions* options_;   // Not owned.
   const bool log_device_placement_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(Placer);
@@ -619,9 +619,9 @@ TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) {
   Node* input = ops::SourceOp(
       "TestDevice",
       b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0"));
-  Node* var = ops::SourceOp("TestVariable",
-                            b.opts().WithName("var_0").WithDevice(
-                                "/job:a/task:0/device:fakegpu:0"));
+  Node* var =
+      ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice(
+                                        "/job:a/task:0/device:fakegpu:0"));
 
   // This op is specified on CPU, but in practice will be ignored,
   // because the reference edges forces it on GPU.
@@ -60,8 +60,8 @@ const string RegisteredFactoriesErrorMessageLocked() {
                          str_util::Join(factory_types, ", "), "}.");
 }
 string SessionOptionsToString(const SessionOptions& options) {
-  return strings::StrCat("target: \"", options.target, "\" config: ",
-                         ProtoShortDebugString(options.config));
+  return strings::StrCat("target: \"", options.target,
+                         "\" config: ", ProtoShortDebugString(options.config));
 }
 }  // namespace
 
@@ -226,22 +226,23 @@ void StepStatsCollector::BuildCostModel(
     if (node) {
       for (int i = 0; i < stats.output_size(); ++i) {
         const auto& output = stats.output(i);
-        cm->RecordMaxMemorySize(node, i, Bytes(output.tensor_description()
-                                                   .allocation_description()
-                                                   .allocated_bytes()),
+        cm->RecordMaxMemorySize(node, i,
+                                Bytes(output.tensor_description()
+                                          .allocation_description()
+                                          .allocated_bytes()),
                                 stats.output(i).tensor_description().shape(),
                                 node->output_types()[i]);
-        cm->RecordAllocationId(node, i, output.tensor_description()
-                                            .allocation_description()
-                                            .allocation_id());
+        cm->RecordAllocationId(node, i,
+                               output.tensor_description()
+                                   .allocation_description()
+                                   .allocation_id());
       }
       cm->RecordMemoryStats(node, stats.memory_stats());
       // Use hardware stats to record the execution time if they're available,
       // otherwise use the regular (less accurate) stats
       string node_name = dev_stats.regular_stats->node_stats(i).node_name();
-      if (dev_stats.hardware_stats &&
-          name_to_hw_node_stats.find(node_name) !=
-              name_to_hw_node_stats.end()) {
+      if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) !=
+                                          name_to_hw_node_stats.end()) {
         const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name];
         cm->RecordMaxExecutionTime(
             node, Microseconds(hw_stats.op_end_rel_micros()));
@@ -80,7 +80,7 @@ void SYCLAllocator::ClearStats() override {
 
 size_t SYCLAllocator::RequestedSize(void* ptr) {
   mutex_lock lock(mu_);
-  if(!sycl_device_) {
+  if (!sycl_device_) {
     return 0;
   }
   const auto& buffer = sycl_device_->get_sycl_buffer(ptr);
@@ -20,10 +20,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
 #define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
@@ -56,14 +56,13 @@ class SYCLAllocator : public Allocator {
   // Clear the SYCL device used by the Allocator
   void ClearSYCLDevice() {
     mutex_lock lock(mu_);
-    if(sycl_device_) {
+    if (sycl_device_) {
       delete sycl_device_;
       sycl_device_ = nullptr;
     }
   }
 
  private:
 
   mutable mutex mu_;
   Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_);  // owned
   AllocatorStats stats_ GUARDED_BY(mu_);
@@ -187,9 +187,9 @@ class GSYCLInterface {
       type = "Unknown";
     }
 
-    return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ",
-                           name.c_str(), ", vendor: ", vendor.c_str(),
-                           ", profile: ", profile.c_str());
+    return strings::StrCat(
+        "id: ", device_id, ", type: ", type, ", name: ", name.c_str(),
+        ", vendor: ", vendor.c_str(), ", profile: ", profile.c_str());
   }
 };
 
@@ -26,7 +26,6 @@ class SYCLDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions &options, const string &name_prefix,
                        std::vector<Device *> *devices) override {
 
     auto syclInterface = GSYCLInterface::instance();
 
     size_t n = 1;
@@ -37,13 +36,11 @@ class SYCLDeviceFactory : public DeviceFactory {
 
     for (int i = 0; i < n; i++) {
       string name = strings::StrCat(name_prefix, "/device:SYCL:", i);
-      devices->push_back(
-          new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality()
-          , syclInterface->GetShortDeviceDescription(i)
-          , syclInterface->GetSYCLAllocator(i)
-          , syclInterface->GetCPUAllocator(i)
-          , syclInterface->GetSYCLContext(i))
-      );
+      devices->push_back(new SYCLDevice(
+          options, name, Bytes(256 << 20), DeviceLocality(),
+          syclInterface->GetShortDeviceDescription(i),
+          syclInterface->GetSYCLAllocator(i), syclInterface->GetCPUAllocator(i),
+          syclInterface->GetSYCLContext(i)));
     }
 
     return Status::OK();
@@ -51,6 +48,6 @@ class SYCLDeviceFactory : public DeviceFactory {
 };
 
 REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200);
-}
+}  // namespace tensorflow
 
 #endif  // TENSORFLOW_USE_SYCL
@@ -20,8 +20,8 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
 
-#include "tensorflow/core/common_runtime/device.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/device.h"
 // For DMA helper
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -24,31 +24,31 @@ limitations under the License.
 namespace tensorflow {
 
 DebugGateway::DebugGateway(DirectSession* session) : session_(session) {
-  session_->node_outputs_callback_ = [this](
-      const string& node_name, const int output_slot, const Tensor* tensor,
-      const bool is_ref, OpKernelContext* ctx) {
+  session_->node_outputs_callback_ =
+      [this](const string& node_name, const int output_slot,
+             const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) {
         if (comp_cb_ != nullptr && output_slot <= 0) {
           // The node completion callback is invoked once for a node regardless
           // of whether the node has zero, one or more outputs.
           // The output_slot can be negative (-1, or kControlSlot) if
-          // node_outputs_callback_ is invoked for a node with no output. If that
-          // is the case, notify the callback that the node in question has no
-          // output.
+          // node_outputs_callback_ is invoked for a node with no output. If
+          // that is the case, notify the callback that the node in question has
+          // no output.
           comp_cb_(node_name, output_slot == 0);
         }
 
         // Copy tensor values (e.g., from GPU to host) only if the
         // value callback is not nullptr.
         if (val_cb_ != nullptr && output_slot >= 0) {
-          CopyTensor(
-              node_name, output_slot, tensor, ctx,
-              [this, node_name, output_slot, is_ref](const Tensor* copied_tensor) {
+          CopyTensor(node_name, output_slot, tensor, ctx,
+                     [this, node_name, output_slot,
+                      is_ref](const Tensor* copied_tensor) {
             val_cb_(node_name, output_slot, *copied_tensor, is_ref);
           });
         }
 
         return Status::OK();
       };
 }
 
 DebugGateway::~DebugGateway() {
@@ -86,7 +86,8 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
   // Determine if the tensor is on device (GPU) or host (CPU).
   // The second part of the check is necessary because even an OpKernel on
   // may have output tensors allocated on CPU.
-  if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) &&
+  if ((device->name().find("GPU:") != string::npos ||
+       device->name().find("SYCL:") != string::npos) &&
       !ctx->output_alloc_attr(output_slot).on_host()) {
     // GPU tensors: Copy it to host (CPU).
     DeviceContext* device_ctxt = ctx->op_device_context();
@@ -390,9 +390,9 @@ TEST_F(SessionDebugMinusAXTest,
   debug_gateway.SetNodeValueCallback(
       [this, &mu, &val_callback_count, &a_debug_identity_node_name,
        &x_debug_identity_node_name, &y_debug_identity_node_name,
-       &debug_identity_tensor_vals, &callbacks_done, &kConcurrentRuns](
-          const string& node_name, const int output_slot,
-          const Tensor& tensor_value, const bool is_ref) {
+       &debug_identity_tensor_vals, &callbacks_done,
+       &kConcurrentRuns](const string& node_name, const int output_slot,
+                         const Tensor& tensor_value, const bool is_ref) {
         mutex_lock l(mu);
 
         if (node_name == a_debug_identity_node_name && output_slot == 0) {
@@ -560,21 +560,21 @@ TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest,
   Notification callbacks_done;
 
   std::vector<Tensor> debug_identity_tensor_vals;
-  debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done,
-                                      &debug_identity_node_name,
-                                      &debug_identity_tensor_vals](
-      const string& node_name, const int output_slot,
-      const Tensor& tensor_value, const bool is_ref) {
+  debug_gateway.SetNodeValueCallback(
+      [this, &mu, &callbacks_done, &debug_identity_node_name,
+       &debug_identity_tensor_vals](
+          const string& node_name, const int output_slot,
+          const Tensor& tensor_value, const bool is_ref) {
         mutex_lock l(mu);
 
         if (node_name == debug_identity_node_name && output_slot == 0) {
           debug_identity_tensor_vals.push_back(tensor_value);
 
           if (!callbacks_done.HasBeenNotified()) {
             callbacks_done.Notify();
           }
         }
       });
 
   // Add DebugIdentity watch on c:0, which does not have an outgoing edge.
   RunOptions run_opts;
@@ -30,7 +30,7 @@ namespace test {
 
 ::grpc::Status TestEventListenerImpl::SendEvents(
     ::grpc::ServerContext* context,
-    ::grpc::ServerReaderWriter< ::tensorflow::EventReply, ::tensorflow::Event>*
+    ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>*
         stream) {
   Event event;
 
@@ -57,7 +57,8 @@ class DebugIOUtilsTest : public ::testing::Test {
 TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
   DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2",
                               "hidden_1/MatMul", 0, "DebugIdentity");
-  EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name);
+  EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2",
+            debug_node_key.device_name);
   EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name);
   EXPECT_EQ(0, debug_node_key.output_slot);
   EXPECT_EQ("DebugIdentity", debug_node_key.debug_op);
@@ -140,7 +140,7 @@ class GraphMgr {
     GraphMgr* graph_mgr;
   };
 
   const WorkerEnv* worker_env_;  // Not owned.
   DeviceMgr* device_mgr_;
 
   CostModelManager cost_model_manager_;
@@ -528,8 +528,8 @@ void Master::ListDevices(const ListDevicesRequest* req,
   auto session = FindMasterSession(req->session_handle());
   if (session == nullptr) {
     done(errors::InvalidArgument(
         "Session ", req->session_handle(),
         " is not found. Possibly, this master has restarted."));
     return;
   }
   core::ScopedUnref ref(session);
@@ -61,7 +61,7 @@ class MasterTest : public ::testing::Test {
   // rpc calls.
 
   Status CreateSession(const GraphDef& def, string* handle,
                        int64* initial_version) {
     ::grpc::ClientContext ctx;
     CreateSessionRequest req;
     *(req.mutable_graph_def()) = def;
@@ -77,7 +77,7 @@ class MasterTest : public ::testing::Test {
   }
 
   Status ExtendSession(const string& handle, const GraphDef& def,
                        int64 current_version, int64* new_version) {
     ::grpc::ClientContext ctx;
     ExtendSessionRequest req;
     req.set_session_handle(handle);
@@ -185,23 +185,22 @@ class GrpcMasterService : public AsyncServiceInterface {
     MutableRunStepResponseWrapper* wrapped_response =
         new NonOwnedProtoRunStepResponse(&call->response);
     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
-    master_impl_->RunStep(call_opts, wrapped_request, wrapped_response,
-                          [call, call_opts, wrapped_request, wrapped_response,
-                           trace](const Status& status) {
-                            call->ClearCancelCallback();
-                            delete call_opts;
-                            delete wrapped_request;
-                            delete trace;
-                            if (call->request.store_errors_in_response_body() &&
-                                !status.ok()) {
-                              call->response.set_status_code(status.code());
-                              call->response.set_status_error_message(
-                                  status.error_message());
-                              call->SendResponse(ToGrpcStatus(Status::OK()));
-                            } else {
-                              call->SendResponse(ToGrpcStatus(status));
-                            }
-                          });
+    master_impl_->RunStep(
+        call_opts, wrapped_request, wrapped_response,
+        [call, call_opts, wrapped_request, wrapped_response,
+         trace](const Status& status) {
+          call->ClearCancelCallback();
+          delete call_opts;
+          delete wrapped_request;
+          delete trace;
+          if (call->request.store_errors_in_response_body() && !status.ok()) {
+            call->response.set_status_code(status.code());
+            call->response.set_status_error_message(status.error_message());
+            call->SendResponse(ToGrpcStatus(Status::OK()));
+          } else {
+            call->SendResponse(ToGrpcStatus(status));
+          }
+        });
     ENQUEUE_REQUEST(RunStep, true);
   }
 
@@ -89,9 +89,9 @@ class MasterService final {
   ::grpc::Status ExtendSession(::grpc::ClientContext* context,
                                const ExtendSessionRequest& request,
                                ExtendSessionResponse* response) override;
-  ::grpc::Status PartialRunSetup(
-      ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
-      PartialRunSetupResponse* response) override;
+  ::grpc::Status PartialRunSetup(::grpc::ClientContext* context,
+                                 const PartialRunSetupRequest& request,
+                                 PartialRunSetupResponse* response) override;
   ::grpc::Status RunStep(::grpc::ClientContext* context,
                          const RunStepRequest& request,
                          RunStepResponse* response) override;
@@ -69,8 +69,7 @@ class GrpcRemoteMaster : public MasterInterface {
     ::grpc::ClientContext ctx;
     auto trace = TraceRpc("RunStep/Client", &ctx);
     return Call(&ctx, call_options, &request->ToProto(),
-                get_proto_from_wrapper(response),
-                &MasterServiceStub::RunStep);
+                get_proto_from_wrapper(response), &MasterServiceStub::RunStep);
   }
 
   Status CloseSession(CallOptions* call_options,
@@ -114,8 +113,9 @@ class GrpcRemoteMaster : public MasterInterface {
   template <typename Request, typename Response>
   Status Call(::grpc::ClientContext* ctx, CallOptions* call_options,
               const Request* request, Response* response,
-              ::grpc::Status (MasterServiceStub::*pfunc)(
-                  ::grpc::ClientContext*, const Request&, Response*)) {
+              ::grpc::Status (MasterServiceStub::*pfunc)(::grpc::ClientContext*,
+                                                         const Request&,
+                                                         Response*)) {
     ctx->set_fail_fast(false);
     SetDeadline(ctx, call_options->GetTimeout());
     return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response));
@@ -21,11 +21,8 @@ namespace tensorflow {
 namespace test {
 
 // ErrorOp::Compute returns an error.
-REGISTER_OP("Error")
-    .Input("in: T")
-    .Output("out: T")
-    .Attr("T: type")
-    .Attr("message: string");
+REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr(
+    "message: string");
 class ErrorOp : public OpKernel {
  public:
   explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -66,11 +63,8 @@ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
 
 // DelayOp::AsyncCompute sleeps for "micros"-econd and then returns
 // its input.
-REGISTER_OP("Delay")
-    .Input("in: T")
-    .Output("out: T")
-    .Attr("T: type")
-    .Attr("micros: int");
+REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr(
+    "micros: int");
 class DelayOp : public AsyncOpKernel {
  public:
   explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
@@ -184,8 +184,8 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
 
   testing::SetLabel(
       strings::StrCat(def.node_size(), " nodes; ",
                       use_multiple_devices ? "Multi device" : "Single device",
                       "; tensor bytes/send: ", tensor_size * sizeof(float)));
 
   std::vector<Tensor> outputs;
 
@@ -17,9 +17,9 @@ limitations under the License.
 
 #include <queue>
 
-#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
@@ -16,15 +16,15 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_
 
-#include <functional>
 #include <deque>
+#include <functional>
 #include <map>
 #include <unordered_map>
 #include <vector>
 
-#include "tensorflow/core/graph/costmodel.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/costmodel.h"
 
 namespace tensorflow {
 
@@ -97,9 +97,8 @@ void WorkerCacheLogger::RecordDataTransfer(int64 step_id, int64 start_usecs,
                                            const string& tensor_name,
                                            const string& src_device,
                                            const string& dst_device,
-                                           int64 bytes,
-                                           const string& details,
-                                           const string& transfer_method_name){
+                                           int64 bytes, const string& details,
+                                           const string& transfer_method_name) {
   NodeExecStats* ns = new NodeExecStats;
   ns->set_node_name(transfer_method_name);
   if (details.empty()) {
@@ -158,8 +158,8 @@ void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
   Ensure(id, 0);
   auto perslot = &slot_bytes_[id];
   if (!perslot->empty()) {
-    CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
-                                           << node->name();
+    CHECK_EQ(num_outputs, perslot->size())
+        << "Cannot resize slot_bytes, node=" << node->name();
   }
   Ensure(id, num_outputs);
 }
@@ -198,7 +198,7 @@ class CostModel {
   // Cumulative execution time.
   std::vector<Microseconds> time_;
   // Cumulative Bytes output on each channel.
-  std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_;
+  std::vector<gtl::InlinedVector<Bytes, 2>> slot_bytes_;
 
   // Maximum execution time
   std::vector<Microseconds> max_exec_time_;
@@ -217,7 +217,7 @@ class CostModel {
   };
   std::vector<MemUsage> max_mem_usage_;
 
-  std::vector<gtl::InlinedVector<int64, 2> > output_port_alloc_ids_;
+  std::vector<gtl::InlinedVector<int64, 2>> output_port_alloc_ids_;
 
   std::set<int64> persistent_alloc_ids_;
   std::map<string, std::set<int64>> persistent_alloc_ids_by_devices_;
@@ -62,8 +62,8 @@ class Node;
 class VersionDef;
 class WhileContext;
 
 class NeighborIter;    // Declared below
 class NodeIter;        // Declared below
 class NodeProperties;  // Defined in .cc
 
 class Node {
@@ -26,7 +26,6 @@ namespace tensorflow {
 namespace {
 
 TEST(GraphDefBuilderTest, Version) {
 
   // Verify that our assertions will be nontrivial
   ASSERT_LT(0, TF_GRAPH_DEF_VERSION);
 
@ -21,102 +21,101 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
// Since our ops are going to produce and also consume N addition tensors
|
// Since our ops are going to produce and also consume N addition tensors
|
||||||
// (Mkl) for N Tensorflow tensors, we can have following different
|
// (Mkl) for N Tensorflow tensors, we can have following different
|
||||||
// orderings among these 2N tensors.
|
// orderings among these 2N tensors.
|
||||||
//
|
//
|
||||||
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
|
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
|
||||||
// consume A_m, B_m, and C_m additionally.
|
// consume A_m, B_m, and C_m additionally.
|
||||||
//
|
//
|
||||||
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
|
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
|
||||||
// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
|
// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
|
||||||
//
|
//
|
||||||
// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
|
// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
|
||||||
// by N Mkl tensors. So for above example, the ordering looks
|
// by N Mkl tensors. So for above example, the ordering looks
|
||||||
// like: A, B, C, A_m, B_m, C_m
|
// like: A, B, C, A_m, B_m, C_m
|
||||||
//
|
//
|
||||||
// Following APIs map index of original Tensorflow tensors to their
|
// Following APIs map index of original Tensorflow tensors to their
|
||||||
// appropriate position based on selected ordering. For contiguous ordering,
|
// appropriate position based on selected ordering. For contiguous ordering,
|
||||||
// we need to know the total number of tensors (parameter total).
|
// we need to know the total number of tensors (parameter total).
|
||||||
//
|
//
|
||||||
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
|
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
|
||||||
// NOTE: Currently, we use contiguous ordering. If you change this, then you
|
// NOTE: Currently, we use contiguous ordering. If you change this, then you
|
||||||
// would need to change Mkl op definitions in nn_ops.cc.
|
// would need to change Mkl op definitions in nn_ops.cc.
|
||||||
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
|
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
|
||||||
|
|
||||||
// Get index of MetaData tensor from index 'n' of Data tensor.
|
// Get index of MetaData tensor from index 'n' of Data tensor.
|
||||||
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
|
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
|
||||||
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
||||||
// For interleaved ordering, Mkl tensor follows immediately after
|
// For interleaved ordering, Mkl tensor follows immediately after
|
||||||
// Tensorflow tensor.
|
// Tensorflow tensor.
|
||||||
return n + 1;
|
return n + 1;
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||||
// For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
|
// For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
|
||||||
return n + total_tensors / 2;
|
return n + total_tensors / 2;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int inline GetTensorDataIndex(int n, int total_tensors) {
|
int inline GetTensorDataIndex(int n, int total_tensors) {
|
||||||
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
||||||
return 2 * n; // index corresponding to nth input/output tensor
|
return 2 * n; // index corresponding to nth input/output tensor
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int inline GetTensorMetaDataIndex(int n, int total_tensors) {
|
int inline GetTensorMetaDataIndex(int n, int total_tensors) {
|
||||||
// Get index for TensorData first and then use mapping function
|
// Get index for TensorData first and then use mapping function
|
||||||
// to get TensorMetaData index from TensorData index.
|
// to get TensorMetaData index from TensorData index.
|
||||||
int tidx = GetTensorDataIndex(n, total_tensors);
|
int tidx = GetTensorDataIndex(n, total_tensors);
|
||||||
return DataIndexToMetaDataIndex(tidx, total_tensors);
|
return DataIndexToMetaDataIndex(tidx, total_tensors);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mkl_op_registry {
|
namespace mkl_op_registry {
|
||||||
static const char* kMklOpLabel = "MklOp";
|
static const char* kMklOpLabel = "MklOp";
|
||||||
static const char* kMklOpLabelPattern = "label='MklOp'";
|
 static const char* kMklOpLabelPattern = "label='MklOp'";

 // Prefix that we add to Tensorflow op name to construct Mkl op name.
 static const char* const kMklOpPrefix = "_Mkl";

 // Get the name of Mkl op from original TensorFlow op
 // We prefix 'Mkl' to the original op to get Mkl op.
 inline string GetMklOpName(const string& name) {
   return string(kMklOpPrefix) + name;
 }

 // Check whether opname with type T is registered as MKL-compliant.
 //
 // @input: name of the op
 // @input: T datatype to be used for checking op
 // @return: true if opname is registered as Mkl op; false otherwise
 static inline bool IsMklOp(const std::string& op_name, DataType T) {
   string kernel = KernelsRegisteredForOp(op_name);
   bool result =
       kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
   return result;
 }

 // Check whether opname with type T is registered as MKL-compliant and
 // is element-wise.
 //
 // @input: name of the op
 // @input: T datatype to be used for checking op
 // @return: true if opname is registered as element-wise Mkl op;
 // false otherwise
-static inline bool IsMklElementWiseOp(const std::string& op_name,
-                                      DataType T) {
+static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
   if (!IsMklOp(op_name, T)) {
     return false;
   }
   bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
                  0 == op_name.compare(GetMklOpName("Sub")) ||
                  0 == op_name.compare(GetMklOpName("Mul")) ||
                  0 == op_name.compare(GetMklOpName("Maximum")) ||
                  0 == op_name.compare(GetMklOpName("SquaredDifference")));

   return result;
 }

 }  // namespace mkl_op_registry
 }  // namespace tensorflow

 #endif  // INTEL_MKL
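For readers skimming the diff, here is a minimal standalone sketch of what these helpers do. The real `KernelsRegisteredForOp` registry lookup is replaced by a hard-coded set, so this mirrors only the name-mangling and classification logic, not the actual kernel-registration check:

#include <iostream>
#include <set>
#include <string>

// Hypothetical stand-ins for the real TensorFlow registry; only the
// name-mangling and element-wise classification logic is mirrored here.
static const char* const kMklOpPrefix = "_Mkl";

std::string GetMklOpName(const std::string& name) {
  return std::string(kMklOpPrefix) + name;
}

bool IsMklElementWiseOp(const std::string& op_name) {
  static const std::set<std::string> kElementWise = {
      GetMklOpName("Add"), GetMklOpName("Sub"), GetMklOpName("Mul"),
      GetMklOpName("Maximum"), GetMklOpName("SquaredDifference")};
  return kElementWise.count(op_name) != 0;
}

int main() {
  std::cout << GetMklOpName("Conv2D") << "\n";            // _MklConv2D
  std::cout << IsMklElementWiseOp("_MklAdd") << "\n";     // 1
  std::cout << IsMklElementWiseOp("_MklConv2D") << "\n";  // 0
}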
@@ -37,8 +37,8 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/tensor_format.h"

-#include "tensorflow/core/graph/mkl_layout_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_layout_pass.h"

 namespace tensorflow {

@@ -281,7 +281,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
     csinfo_.mkl_conv2d_with_bias_backprop_bias =
         "_MklConv2DWithBiasBackpropBias";
     csinfo_.relu = "Relu";
     csinfo_.relu_grad = "ReluGrad";
     csinfo_.reshape = "Reshape";
@@ -297,10 +297,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // End - element-wise ops. See note above.

     // NOTE: names are alphabetically sorted.
-    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN,
-                      AddNRewrite, nullptr});
-    rinfo_.push_back({csinfo_.add,
-                      mkl_op_registry::GetMklOpName(csinfo_.add),
+    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
+                      CopyAttrsAddN, AddNRewrite, nullptr});
+    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.avg_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool),

@@ -337,14 +336,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.fused_batch_norm,
                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                       CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
-                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+    rinfo_.push_back(
+        {csinfo_.fused_batch_norm_grad,
+         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
+         CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsIdentity, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.lrn,
-                      mkl_op_registry::GetMklOpName(csinfo_.lrn),
+    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                       CopyAttrsLRN, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.lrn_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),

@@ -358,11 +357,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.maximum,
                       mkl_op_registry::GetMklOpName(csinfo_.maximum),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.mul,
-                      mkl_op_registry::GetMklOpName(csinfo_.mul),
+    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.relu,
-                      mkl_op_registry::GetMklOpName(csinfo_.relu),
+    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.relu_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),

@@ -373,8 +370,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.squared_difference,
                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.sub,
-                      mkl_op_registry::GetMklOpName(csinfo_.sub),
+    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});

     // Add info about which ops to add workspace edge to and the slots.
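Each `rinfo_` entry registered above pairs an original op name with its `_Mkl` replacement, an attribute-copying callback, and a rewrite predicate. A minimal standalone sketch of that table-driven lookup follows; the struct and predicate signature here are simplified stand-ins, not the pass's real `RewriteInfo`:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the pass's rewrite table; the real entries also
// carry an attribute-copying callback and optional context information.
struct RewriteInfo {
  std::string name;                                       // original op
  std::string new_name;                                   // Mkl replacement op
  std::function<bool(const std::string&)> rewrite_rule;   // gate
};

int main() {
  auto always = [](const std::string&) { return true; };
  std::vector<RewriteInfo> rinfo = {
      {"Relu", "_MklRelu", always},
      {"Conv2D", "_MklConv2D", always},
  };

  std::string node_op = "Conv2D";
  for (const auto& ri : rinfo) {
    if (node_op == ri.name && ri.rewrite_rule(node_op)) {
      std::cout << node_op << " -> " << ri.new_name << "\n";  // Conv2D -> _MklConv2D
      break;
    }
  }
}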
@@ -388,9 +384,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
                                    IsBiasAddGradInMatMulContext};

-    biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
-                                           csinfo_.mkl_conv2d_with_bias,
-                                           IsBiasAddGradInConv2DWithBiasContext};
+    biasaddgrad_conv2dwithbias_context_ = {
+        csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
+        IsBiasAddGradInConv2DWithBiasContext};

     cinfo_.push_back(&biasaddgrad_matmul_context_);
     cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);

@@ -410,9 +406,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

   /// Structure to specify the context information used in a node rewrite rule
   typedef struct {
     string node;  // Name of the node to be rewritten
     string fwd;   // Name of the node in the forward pass that this node
                   // corresponds to
     std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
   } ContextInfo;
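The `ContextInfo` records let one node type (here `BiasAddGrad`) be rewritten differently depending on which forward op it pairs with. A standalone sketch of that dispatch, with toy predicates standing in for the real graph walks:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy model: a context entry names the node to rewrite, the forward op it
// must pair with, and a predicate that decides whether the pairing holds.
struct ContextInfo {
  std::string node;
  std::string fwd;
  std::function<bool(const std::string& fwd_op)> context_match_fn;
};

int main() {
  std::vector<ContextInfo> cinfo = {
      {"BiasAddGrad", "MatMul",
       [](const std::string& f) { return f == "MatMul"; }},
      {"BiasAddGrad", "_MklConv2DWithBias",
       [](const std::string& f) { return f == "_MklConv2DWithBias"; }},
  };

  std::string fwd_op = "_MklConv2DWithBias";  // found by walking the graph
  for (const auto& ci : cinfo) {
    if (ci.context_match_fn(fwd_op)) {
      std::cout << ci.node << " matched context " << ci.fwd << "\n";
      break;
    }
  }
}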
@@ -615,14 +611,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     std::vector<int32> ksize, strides;
     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
-    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
-             true);
+    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);

     // Condition that specifies non-batch-wise and non-depth-wise pooling.
     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
         GetTensorDim(strides, data_format, 'N') == 1 &&
         GetTensorDim(ksize, data_format, 'C') == 1 &&
         GetTensorDim(strides, data_format, 'C') == 1) {
       return true;
     }
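The pooling condition reads the `N` (batch) and `C` (depth) entries of `ksize`/`strides` through the data format. A standalone sketch of the same check, assuming a plain format string such as "NHWC" in place of TensorFlow's `TensorFormat` machinery:

#include <iostream>
#include <string>
#include <vector>

// Assumed simplification: index ksize/strides by the position of the
// dimension letter in the format string ("NHWC", "NCHW", ...).
int GetTensorDim(const std::vector<int>& v, const std::string& format,
                 char dim) {
  return v[format.find(dim)];
}

// Rewrite only pooling that acts on spatial dims, i.e. leaves batch (N)
// and depth/channel (C) untouched.
bool IsSpatialPooling(const std::vector<int>& ksize,
                      const std::vector<int>& strides,
                      const std::string& format) {
  return GetTensorDim(ksize, format, 'N') == 1 &&
         GetTensorDim(strides, format, 'N') == 1 &&
         GetTensorDim(ksize, format, 'C') == 1 &&
         GetTensorDim(strides, format, 'C') == 1;
}

int main() {
  std::cout << IsSpatialPooling({1, 2, 2, 1}, {1, 2, 2, 1}, "NHWC") << "\n";  // 1
  std::cout << IsSpatialPooling({2, 1, 1, 1}, {2, 1, 1, 1}, "NHWC") << "\n";  // 0 (batch-wise)
}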
@@ -785,8 +780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     for (const Edge* fe : first_inp_of_filter->out_edges()) {
       if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
           fe->dst_input() == 0) {
-        VLOG(1) << "MklLayoutRewritePass: found "
-                << fe->dst()->DebugString()
+        VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString()
                 << " as the forward node for matching context, backward"
                 << " node is: " << n->DebugString();
         *fwd_node = fe->dst();

@@ -803,13 +797,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   //
   // @return - true (if BiasAddGrad is associated with MatMul);
   //           false otherwise.
-  static bool IsBiasAddGradInMatMulContext(const Node* n,
-                                           const Node** fwd_node,
+  static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node,
                                            void* ci) {
     return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
   }

-
   // Rewrite rule that uses context-information for matching,
   // used in scenario 2.
   //
@@ -880,10 +872,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // @output output_nodes - the list of new nodes creating Mkl tensors
   //
   // @return None
-  void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
-      Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-      int* input_idx, int list_length,
-      std::vector<NodeBuilder::NodeOut>* output_nodes);
+  void GetNodesProducingMklTensorList(
+      std::unique_ptr<Graph>* g, Node* orig_node,
+      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+      int* input_idx, int list_length,
+      std::vector<NodeBuilder::NodeOut>* output_nodes);

   // Get a node that will feed an Mkl tensor to the new
   // node that we are constructing. The output node could be (1) 'n'

@@ -900,7 +893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // will feed the tensor
   // @return None
   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
-      Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
+                                 Node* n, int n_output_slot, Node** mkl_node,
+                                 int* mkl_node_output_slot);

   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are

@@ -970,9 +964,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
 MklLayoutRewritePass::ContextInfo
     MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
 MklLayoutRewritePass::ContextInfo
     MklLayoutRewritePass::biasaddgrad_matmul_context_;
 std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;

 // We register Mkl rewrite pass for phase 1 in post partitioning group.
@@ -1041,13 +1035,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   TensorShape dummy_shape({8});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                   .Attr("value", proto)
                   .Attr("dtype", dt)
                   .Device(orig_node->def().device())  // We place this node on
                                                       // the same device as the
                                                       // device of the original
                                                       // node.
                   .Finalize(&**g, out));

   // If number of inputs to the original node is > 0, then we add
   // control dependency between 1st input (index 0) of the original node and

@@ -1060,8 +1054,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   // the same frame.
   if (orig_node->num_inputs() > 0) {
     Node* orig_input0 = nullptr;
-    TF_CHECK_OK(orig_node->input_node(0,
-                const_cast<const Node**>(&orig_input0)));
+    TF_CHECK_OK(
+        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
   }

@@ -1069,11 +1063,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
 }

 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
-    std::unique_ptr<Graph>* g,
-    Node* orig_node,
-    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-    int* input_idx, int list_length,
-    std::vector<NodeBuilder::NodeOut>* output_nodes) {
+    std::unique_ptr<Graph>* g, Node* orig_node,
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+    int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   CHECK_LT(*input_idx, inputs.size());
   CHECK_GT(list_length, 0);
   CHECK_NOTNULL(output_nodes);

@@ -1090,8 +1082,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
     int mkl_node_output_slot = 0;
     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
                               &mkl_node_output_slot);
-    output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
-                                                 mkl_node_output_slot));
+    output_nodes->push_back(
+        NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
     (*input_idx)++;
     list_length--;
   }

@@ -1101,9 +1093,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
 // node that we are constructing. An input node could be (1) 'n'
 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 // if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
-    Node* orig_node, Node* n,
-    int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+    std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
+    Node** mkl_node, int* mkl_node_output_slot) {
   CHECK_NOTNULL(n);
   CHECK_NOTNULL(mkl_node);
   CHECK_NOTNULL(mkl_node_output_slot);

@@ -1234,8 +1226,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
       int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
-                                     N, &new_node_inputs);
+      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
+                                     &new_node_inputs);
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
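`SetUpContiguousInputs` assembles the rewritten node's inputs in what the pass calls contiguous ordering: all data tensors first, then the matching Mkl metadata tensors in the same order (as opposed to interleaving d0, m0, d1, m1, ...). A sketch of that layout under those assumptions:

#include <iostream>
#include <string>
#include <vector>

// Contiguous ordering: [d0..dn-1, m0..mn-1]; each data input di is paired
// with a metadata input mi describing its MKL layout.
std::vector<std::string> ContiguousInputs(
    const std::vector<std::string>& data,
    const std::vector<std::string>& meta) {
  std::vector<std::string> out(data);
  out.insert(out.end(), meta.begin(), meta.end());
  return out;
}

int main() {
  for (const auto& s :
       ContiguousInputs({"input", "filter"}, {"input:mkl", "filter:mkl"})) {
    std::cout << s << "\n";  // input, filter, input:mkl, filter:mkl
  }
}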
@@ -1336,13 +1328,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   TensorShape dummy_shape({1});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                   .Attr("value", proto)
                   .Attr("dtype", dt)
                   .Device(orig_node->def().device())  // We place this node on
                                                       // same the device as the
                                                       // device of the original
                                                       // node.
                   .Finalize(&**g, out));

   // If number of inputs to the original node is > 0, then we add
   // control dependency between 1st input (index 0) of the original node and

@@ -1355,8 +1347,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   // the same frame.
   if (orig_node->num_inputs() > 0) {
     Node* orig_input0 = nullptr;
-    TF_CHECK_OK(orig_node->input_node(0,
-                const_cast<const Node**>(&orig_input0)));
+    TF_CHECK_OK(
+        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
   }

@@ -1374,7 +1366,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   for (auto ws : wsinfo_) {
     if (orig_node->type_string() == ws.fwd_op &&
-        mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
+        mkl_op_registry::IsMklOp(
+            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
       // If this op is a fwd op, then we need to check if there is an
       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
       // an edge, then we just add an attribute on this node for setting

@@ -1400,8 +1393,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
         nb->Attr("workspace_enabled", false);
       }
     } else if (orig_node->type_string() == ws.bwd_op &&
-               mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()),
-                                        T)) {
+               mkl_op_registry::IsMklOp(
+                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
+                   T)) {
       // If this op is a bwd op, then we need to add workspace edge and
       // it's Mkl tensor edge between its corresponding fwd op and this
       // op. Corresponding fwd op is specified in 'fwd_op' field of

@@ -1416,7 +1410,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
         if (e->src_output() == ws.fwd_slot &&
             // We would have rewritten the forward op, so we need to use
             // GetMklOpName call to get its Mkl name.
-            e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) &&
+            e->src()->type_string() ==
+                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
             e->dst_input() == ws.bwd_slot) {
           nb->Attr("workspace_enabled", true);
           CHECK_NOTNULL(ws_tensors);
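The workspace bookkeeping pairs a forward op with its backward op through the slot numbers stored in `wsinfo_`: only when an edge actually runs from the forward op's `fwd_slot` to the backward op's `bwd_slot` do both sides get `workspace_enabled` set to true. A standalone sketch of that edge test with toy types:

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for the pass's workspace bookkeeping.
struct WorkSpaceInfo {
  std::string fwd_op, bwd_op;
  int fwd_slot, bwd_slot;
};

struct Edge {
  std::string src_op, dst_op;
  int src_output, dst_input;
};

// True if some edge connects fwd_op:fwd_slot to bwd_op:bwd_slot, i.e. the
// backward op really consumes the forward op's workspace-producing output.
bool NeedsWorkspaceEdge(const WorkSpaceInfo& ws,
                        const std::vector<Edge>& edges) {
  for (const Edge& e : edges) {
    if (e.src_op == ws.fwd_op && e.src_output == ws.fwd_slot &&
        e.dst_op == ws.bwd_op && e.dst_input == ws.bwd_slot) {
      return true;
    }
  }
  return false;
}

int main() {
  WorkSpaceInfo ws{"_MklMaxPool", "_MklMaxPoolGrad", 0, 1};
  std::vector<Edge> edges = {{"_MklMaxPool", "_MklMaxPoolGrad", 0, 1}};
  std::cout << NeedsWorkspaceEdge(ws, edges) << "\n";  // 1
}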
@@ -1593,7 +1588,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
 }

 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
                                             NodeBuilder* nb) {
   DataType T;
   DataType Tshape;

@@ -1869,8 +1864,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
     if (e->IsControlEdge()) {
       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
     } else {
-      CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
-                                  e->dst_input()));
+      CHECK_NOTNULL(
+          (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()));
     }
   }

@@ -1941,9 +1936,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
       // and leave BiasAddGrad as it is. But we check for this condition
       // when we check for node rewrite rule. So we should not even come
       // here for MatMul. So we will fail now.
       return Status(
           error::Code::INVALID_ARGUMENT,
           "No rewrite is required for BiasAddGrad for MatMul context.");
     }
   }

@@ -2012,9 +2007,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
     if (e->IsControlEdge()) {
       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
     } else {
-      CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
-                                                e->src()->num_outputs()),
-                                  e->dst(), e->dst_input()));
+      CHECK_NOTNULL((*g)->AddEdge(
+          new_node,
+          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+          e->dst(), e->dst_input()));
     }
   }

@@ -2070,7 +2066,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {

   // BiasAddGrad is not an Mkl layer, so we make an exception for it.
   if (n->type_string() != csinfo_.bias_add_grad) {
-    if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+    if (!mkl_op_registry::IsMklOp(
+            mkl_op_registry::GetMklOpName(n->type_string()), T)) {
       return nullptr;
     }
   }

@@ -2186,8 +2183,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   return MklLayoutRewritePass().RunPass(g);
 }

-Status MklLayoutRewritePass::Run(
-    const GraphOptimizationPassOptions& options) {
+Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
   if (options.graph == nullptr && options.partition_graphs == nullptr) {
     return Status::OK();
   }

@@ -2215,7 +2211,7 @@ Status MklLayoutRewritePass::Run(
   return Status::OK();
 }

 #else  // INTEL_MKL_DNN

 // This pass implements rewriting of graph to support following scenarios:
 // (A) Merging nodes in the graph
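`Run` may be handed a whole graph or a list of partitioned graphs and must cope with either being absent. A minimal standalone sketch of that dispatch shape, with trivial stand-in types rather than the real `GraphOptimizationPassOptions`:

#include <iostream>
#include <memory>
#include <vector>

// Trivial stand-ins; the real pass receives a GraphOptimizationPassOptions
// holding either `graph` or `partition_graphs`.
struct Graph { int num_nodes = 0; };
struct PassOptions {
  std::unique_ptr<Graph>* graph = nullptr;
  std::vector<std::unique_ptr<Graph>>* partition_graphs = nullptr;
};

bool RunPassOn(std::unique_ptr<Graph>* g) {
  return g && *g && (*g)->num_nodes >= 0;  // placeholder for the rewrite
}

bool Run(const PassOptions& options) {
  if (options.graph == nullptr && options.partition_graphs == nullptr) {
    return true;  // nothing to do; the real code returns Status::OK()
  }
  if (options.graph != nullptr) {
    return RunPassOn(options.graph);
  }
  bool ok = true;
  for (auto& g : *options.partition_graphs) ok = RunPassOn(&g) && ok;
  return ok;
}

int main() {
  std::unique_ptr<Graph> g(new Graph);
  PassOptions opts;
  opts.graph = &g;
  std::cout << Run(opts) << "\n";  // 1
}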
@@ -2421,7 +2417,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
     csinfo_.conv2d_grad_filter_with_bias =
         "__MklDummyConv2DBackpropFilterWithBias";
     csinfo_.fused_batch_norm = "FusedBatchNorm";
     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
     csinfo_.identity = "Identity";

@@ -2435,11 +2431,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
     csinfo_.mkl_conv2d_grad_filter_with_bias =
         "_MklConv2DBackpropFilterWithBias";
     csinfo_.relu = "Relu";
     csinfo_.relu_grad = "ReluGrad";
     csinfo_.tanh = "Tanh";
     csinfo_.tanh_grad = "TanhGrad";
     csinfo_.reshape = "Reshape";
     csinfo_.softmax = "Softmax";
     csinfo_.split = "Split";

@@ -2474,29 +2470,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.conv2d,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                       CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2d_with_bias,
-                      csinfo_.mkl_conv2d_with_bias,
+    rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                       CopyAttrsConv2D, AlwaysRewrite});
     rinfo_.push_back({csinfo_.conv2d_grad_filter,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                       CopyAttrsConv2D, AlwaysRewrite});
     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
-                      csinfo_.mkl_conv2d_grad_filter_with_bias,
-                      CopyAttrsConv2D, AlwaysRewrite});
+                      csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+                      AlwaysRewrite});
     rinfo_.push_back({csinfo_.conv2d_grad_input,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                       CopyAttrsConv2D, AlwaysRewrite});
     rinfo_.push_back({csinfo_.fused_batch_norm,
                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                       CopyAttrsFusedBatchNorm, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
-                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
+    rinfo_.push_back(
+        {csinfo_.fused_batch_norm_grad,
+         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
+         CopyAttrsFusedBatchNorm, AlwaysRewrite});
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsDataType, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.lrn,
-                      mkl_op_registry::GetMklOpName(csinfo_.lrn),
+    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                       CopyAttrsLRN, AlwaysRewrite});
     rinfo_.push_back({csinfo_.lrn_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),

@@ -2515,8 +2510,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       mkl_op_registry::GetMklOpName(csinfo_.mul),
                       CopyAttrsDataType, AlwaysRewrite});
     */
-    rinfo_.push_back({csinfo_.relu,
-                      mkl_op_registry::GetMklOpName(csinfo_.relu),
+    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                       CopyAttrsDataType, AlwaysRewrite});
     rinfo_.push_back({csinfo_.relu_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),

@@ -2550,8 +2544,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

     // Add a rule for merging nodes
     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
-                      csinfo_.conv2d_with_bias,
-                      GetConv2DOrBiasAdd});
+                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

     minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                       csinfo_.conv2d_grad_filter_with_bias,
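The `minfo_` entries drive node merging: when one op feeds the other, the pair collapses into a single fused op. A standalone sketch of such a merge table; the `Conv2D`/`BiasAdd` merged name below is illustrative only, while the BackpropFilter one appears verbatim in this diff:

#include <iostream>
#include <string>
#include <vector>

// Simplified merge rule: if op1 feeds op2, replace the pair by `merged`.
struct MergeInfo {
  std::string op1, op2, merged;
};

std::string FindMergedOp(const std::vector<MergeInfo>& minfo,
                         const std::string& producer,
                         const std::string& consumer) {
  for (const auto& mi : minfo) {
    if (mi.op1 == producer && mi.op2 == consumer) return mi.merged;
  }
  return "";
}

int main() {
  std::vector<MergeInfo> minfo = {
      {"Conv2D", "BiasAdd", "Conv2DWithBias"},  // merged name illustrative
      {"Conv2DBackpropFilter", "BiasAddGrad",
       "__MklDummyConv2DBackpropFilterWithBias"},
  };
  std::cout << FindMergedOp(minfo, "Conv2DBackpropFilter", "BiasAddGrad")
            << "\n";
}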
@@ -2846,9 +2839,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

   // Default rewrite rule to be used in scenario 1 for rewrite.
   // @return - true (since we want to always rewrite)
-  static bool AlwaysRewrite(const Node* n) {
-    return true;
-  }
+  static bool AlwaysRewrite(const Node* n) { return true; }

   // Check if we are performing pooling on depth or batch. If it is, then we
   // do not rewrite MaxPool node to Mkl version.

@@ -2862,14 +2853,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     std::vector<int32> ksize, strides;
     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
-    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
-             true);
+    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);

     // Condition that specifies non-batch-wise and non-depth-wise pooling.
     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
         GetTensorDim(strides, data_format, 'N') == 1 &&
         GetTensorDim(ksize, data_format, 'C') == 1 &&
         GetTensorDim(strides, data_format, 'C') == 1) {
       return true;
     }
@@ -2941,10 +2931,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // @output output_nodes - the list of new nodes creating Mkl tensors
   //
   // @return None
-  void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
-      Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-      int* input_idx, int list_length,
-      std::vector<NodeBuilder::NodeOut>* output_nodes);
+  void GetNodesProducingMklTensorList(
+      std::unique_ptr<Graph>* g, Node* orig_node,
+      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+      int* input_idx, int list_length,
+      std::vector<NodeBuilder::NodeOut>* output_nodes);

   // Get a node that will feed an Mkl tensor to the new
   // node that we are constructing. The output node could be (1) 'n'

@@ -2961,7 +2952,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // will feed the tensor
   // @return None
   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
-      Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
+                                 Node* n, int n_output_slot, Node** mkl_node,
+                                 int* mkl_node_output_slot);

   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are

@@ -3096,13 +3088,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   TensorShape dummy_shape({8});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                   .Attr("value", proto)
                   .Attr("dtype", dt)
                   .Device(orig_node->def().device())  // We place this node on
                                                       // the same device as the
                                                       // device of the original
                                                       // node.
                   .Finalize(&**g, out));

   // If number of inputs to the original node is > 0, then we add
   // control dependency between 1st input (index 0) of the original node and

@@ -3115,8 +3107,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   // the same frame.
   if (orig_node->num_inputs() > 0) {
     Node* orig_input0 = nullptr;
-    TF_CHECK_OK(orig_node->input_node(0,
-                const_cast<const Node**>(&orig_input0)));
+    TF_CHECK_OK(
+        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
     // Allow duplicate while adding control edge as it would fail (return
     // NULL) if we try to add duplicate edge.
     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));

@@ -3126,11 +3118,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
 }

 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
-    std::unique_ptr<Graph>* g,
-    Node* orig_node,
-    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-    int* input_idx, int list_length,
-    std::vector<NodeBuilder::NodeOut>* output_nodes) {
+    std::unique_ptr<Graph>* g, Node* orig_node,
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+    int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   CHECK_LT(*input_idx, inputs.size());
   CHECK_GT(list_length, 0);
   CHECK_NOTNULL(output_nodes);

@@ -3147,8 +3137,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
     int mkl_node_output_slot = 0;
     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
                               &mkl_node_output_slot);
-    output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
-                                                 mkl_node_output_slot));
+    output_nodes->push_back(
+        NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
     (*input_idx)++;
     list_length--;
   }

@@ -3158,9 +3148,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
 // node that we are constructing. An input node could be (1) 'n'
 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 // if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
-    Node* orig_node, Node* n,
-    int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+    std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
+    Node** mkl_node, int* mkl_node_output_slot) {
   CHECK_NOTNULL(n);
   CHECK_NOTNULL(mkl_node);
   CHECK_NOTNULL(mkl_node_output_slot);
@@ -3292,8 +3282,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
       int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
-                                     N, &new_node_inputs);
+      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
+                                     &new_node_inputs);
      nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {

@@ -3394,13 +3384,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   TensorShape dummy_shape({1});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                   .Attr("value", proto)
                   .Attr("dtype", dt)
                   .Device(orig_node->def().device())  // We place this node on
                                                       // same the device as the
                                                       // device of the original
                                                       // node.
                   .Finalize(&**g, out));

   // If number of inputs to the original node is > 0, then we add
   // control dependency between 1st input (index 0) of the original node and

@@ -3413,8 +3403,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   // the same frame.
   if (orig_node->num_inputs() > 0) {
     Node* orig_input0 = nullptr;
-    TF_CHECK_OK(orig_node->input_node(0,
-                const_cast<const Node**>(&orig_input0)));
+    TF_CHECK_OK(
+        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
     // Allow duplicate while adding control edge as it would fail (return
     // NULL) if we try to add duplicate edge.
     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));

@@ -3434,8 +3424,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   for (auto ws : wsinfo_) {
     if (orig_node->type_string() == ws.fwd_op &&
-        mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(
-                                     orig_node->type_string()), T)) {
+        mkl_op_registry::IsMklOp(
+            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
       // If this op is a fwd op, then we need to check if there is an
       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
       // an edge, then we just add an attribute on this node for setting

@@ -3461,8 +3451,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
         nb->Attr("workspace_enabled", false);
       }
     } else if (orig_node->type_string() == ws.bwd_op &&
-               mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(
-                                            orig_node->type_string()), T)) {
+               mkl_op_registry::IsMklOp(
+                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
+                   T)) {
       // If this op is a bwd op, then we need to add workspace edge and
       // it's Mkl tensor edge between its corresponding fwd op and this
       // op. Corresponding fwd op is specified in 'fwd_op' field of

@@ -3477,8 +3468,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
         if (e->src_output() == ws.fwd_slot &&
             // We would have rewritten the forward op, so we need to use
             // GetMklOpName call to get its Mkl name.
-            e->src()->type_string() == mkl_op_registry::GetMklOpName(
-                                           ws.fwd_op) &&
+            e->src()->type_string() ==
+                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
             e->dst_input() == ws.bwd_slot) {
           nb->Attr("workspace_enabled", true);
           CHECK_NOTNULL(ws_tensors);
@@ -3645,7 +3636,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
 }

 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
                                             NodeBuilder* nb) {
   DataType T;
   DataType Tshape;

@@ -3776,8 +3767,9 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
                                                     Node* m, Node* n) {
   CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
              n->type_string() == csinfo_.conv2d)) ||
                ((n->type_string() == csinfo_.bias_add &&
-                 m->type_string() == csinfo_.conv2d)), true);
+                 m->type_string() == csinfo_.conv2d)),
+           true);

   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
   // BiasAdd is successor node, and Conv2D predecessor node.

@@ -3796,8 +3788,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
-  TF_CHECK_OK(
-      GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
+  TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
   // We check to ensure that data formats of both succ and pred are same.
   // We expect them to be same, so we can enforce this as assert.
   // But assert can be too strict, so we enforce this as a check.

@@ -3900,8 +3891,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
       // output (at slot 0).
       const int kConv2DWithBiasOutputSlot = 0;
-      CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
-                                  e->dst(), e->dst_input()));
+      CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
+                                  e->dst_input()));
     }
   }

@@ -3924,8 +3915,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
     std::unique_ptr<Graph>* g, Node* m, Node* n) {
   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
              n->type_string() == csinfo_.conv2d_grad_filter)) ||
                ((n->type_string() == csinfo_.bias_add_grad &&
-                 m->type_string() == csinfo_.conv2d_grad_filter)), true);
+                 m->type_string() == csinfo_.conv2d_grad_filter)),
+           true);

   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;

@@ -4132,9 +4124,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
       // NULL) if we try to add duplicate edge.
       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
     } else {
-      CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
-                                                e->src()->num_outputs()),
-                                  e->dst(), e->dst_input()));
+      CHECK_NOTNULL((*g)->AddEdge(
+          new_node,
+          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+          e->dst(), e->dst_input()));
     }
   }
@@ -4166,9 +4159,9 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   // names.
   if (n->type_string() != csinfo_.conv2d_with_bias &&
       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
-      !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(
-                                    n->type_string()), T)) {
+      !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
+                                T)) {
     return nullptr;
   }

   // For elementwise node, we reuse the Eigen implementation and pass the MKL

@@ -4184,29 +4177,30 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   // eigen code to reduce cross-library dependency.
   VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
   if (mkl_op_registry::IsMklElementWiseOp(
           mkl_op_registry::GetMklOpName(n->type_string()), T) ||
       n->type_string().find("Identity") != string::npos) {
     VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
     bool incoming_mkl_edge = false;
     int num_parent = 0;
     for (auto parent : n->in_edges()) {
       if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
-        VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is MKL op: "
-                << parent->src()->type_string();
+        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
+                << " is MKL op: " << parent->src()->type_string();
         incoming_mkl_edge = true;
         break;
       } else {
-        VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is NON-MKL op: "
-                << parent->src()->type_string();
+        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
+                << " is NON-MKL op: " << parent->src()->type_string();
       }
     }
     if (incoming_mkl_edge == false) {
-      VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which has no MKL "
+      VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
+                 "has no MKL "
                  "parents.";
       return nullptr;
     } else {
-      VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() <<
-                 " which has MKL parents";
+      VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
+              << " which has MKL parents";
     }
   }

@@ -4214,8 +4208,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   // for this op, then we rewrite it to Mkl op.
   // Find matching RewriteInfo and then check that rewrite rule applies.
   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
-    if (n->type_string().compare(ri->name) == 0 &&
-        ri->rewrite_rule(n)) {
+    if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
      return &*ri;
     }
   }
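The element-wise gate above only rewrites an op when at least one parent already produces MKL tensors, since a lone element-wise op would just pay layout-conversion overhead. A standalone sketch of that scan over incoming edges, with a prefix test standing in for the registry lookup:

#include <iostream>
#include <string>
#include <vector>

// Toy node: just an op type and its producers.
struct Node {
  std::string type;
  std::vector<const Node*> parents;
};

// Stand-in for the registry lookup: treat ops named "_Mkl*" as MKL ops.
bool IsMklOp(const std::string& op) { return op.rfind("_Mkl", 0) == 0; }

// Rewrite an element-wise node only if some parent feeds it MKL tensors.
bool ShouldRewriteElementWise(const Node& n) {
  for (const Node* parent : n.parents) {
    if (IsMklOp(parent->type)) return true;
  }
  return false;
}

int main() {
  Node conv{"_MklConv2D", {}};
  Node input{"Input", {}};
  Node add1{"Add", {&conv, &input}};
  Node add2{"Add", {&input, &input}};
  std::cout << ShouldRewriteElementWise(add1) << "\n";  // 1
  std::cout << ShouldRewriteElementWise(add2) << "\n";  // 0
}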
@@ -4297,8 +4290,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   return MklLayoutRewritePass().RunPass(g);
 }

-Status MklLayoutRewritePass::Run(
-    const GraphOptimizationPassOptions& options) {
+Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
   if (options.graph == nullptr && options.partition_graphs == nullptr) {
     return Status::OK();
   }
@@ -125,8 +125,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
 REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
 REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2").Output("o: uint8")
-    .Output("o1: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2")
+    .Output("o: uint8")
+    .Output("o1: uint8")
+    .SetIsStateful();
 
 /////////////////////////////////////////////////////////////////////
 // Unit tests related to node merge optiimization
@@ -498,7 +500,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
             "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
 }
 
-
 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
   InitGraph(
@@ -874,11 +875,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
       " input: ['A', 'B:0', 'B:1']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'D'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
-            "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
-            "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
+      "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
+      "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
+      "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
 }
 
 // Concat with 2 Mkl layers feeding it
@@ -1273,7 +1275,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
       "node { name: 'H' op: 'Input'}"
       "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['H', 'G'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
       "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
       "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
@@ -1640,7 +1643,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
       " attr { key: 'padding' value { s: 'SAME' } }"
       " input: ['A', 'B']}"
       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['B', 'C'] }", kGPUDevice);
+      " input: ['B', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
 }
@@ -1666,7 +1670,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
       "node { name: 'F' op: 'BiasAddGrad'"
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " attr { key: 'data_format' value { s: 'NCHW' } }"
-      " input: ['E'] }", kGPUDevice);
+      " input: ['E'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
@@ -1687,7 +1692,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
       " attr { key: 'padding' value { s: 'SAME' } }"
       " input: ['A', 'B', 'C']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'D'] }", kGPUDevice);
+      " input: ['A', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
             "A->D;A->E;B->D:1;C->D:2;D->E:1");
@@ -1700,7 +1706,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -1713,7 +1720,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['A', 'B'] }"
       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'C'] }", kGPUDevice);
+      " input: ['A', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
 }
@@ -1729,7 +1737,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
       " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -1745,7 +1754,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
       " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -1766,7 +1776,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
       " attr { key: 'N' value { i: 2 } }"
       " input: ['A', 'B:0', 'B:1']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
             "B->D:1;B:1->D:2;C->E;D->E:1");
@@ -1788,7 +1799,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
       " attr { key: 'N' value { i: 2 } }"
       " input: ['B:0', 'B:1', 'A']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
             "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
@@ -1808,7 +1820,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
       " attr { key: 'is_training' value { b: true } }"
       " input: ['A', 'B', 'C', 'D', 'E'] }"
       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'F'] }", kGPUDevice);
+      " input: ['A', 'F'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);E(Input);"
             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
@@ -1837,7 +1850,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
       "node { name: 'Y' op: 'Input'}"
       "node { name: 'Z' op: 'Zeta'"
       " attr {key: 'T' value { type: DT_FLOAT } }"
-      " input: ['E', 'Y']}", kGPUDevice);
+      " input: ['E', 'Y']}",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
@@ -1972,8 +1986,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
 REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
 REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2").Output("o: uint8")
-    .Output("o1: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2")
+    .Output("o: uint8")
+    .Output("o1: uint8")
+    .SetIsStateful();
 
 /////////////////////////////////////////////////////////////////////
 // Unit tests related to node merge optiimization
@@ -2492,11 +2508,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
       " input: ['A', 'B:0', 'B:1']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'D'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
-            "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
-            "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
+      "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
+      "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
+      "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
 }
 
 // Concat with 2 Mkl layers feeding it
@@ -2891,7 +2908,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
       "node { name: 'H' op: 'Input'}"
       "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['H', 'G'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
       "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
       "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
@@ -3258,7 +3276,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
       " attr { key: 'padding' value { s: 'SAME' } }"
       " input: ['A', 'B']}"
       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['B', 'C'] }", kGPUDevice);
+      " input: ['B', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
 }
@@ -3284,7 +3303,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
       "node { name: 'F' op: 'BiasAddGrad'"
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " attr { key: 'data_format' value { s: 'NCHW' } }"
-      " input: ['E'] }", kGPUDevice);
+      " input: ['E'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
@@ -3305,7 +3325,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
       " attr { key: 'padding' value { s: 'SAME' } }"
       " input: ['A', 'B', 'C']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'D'] }", kGPUDevice);
+      " input: ['A', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
             "A->D;A->E;B->D:1;C->D:2;D->E:1");
@@ -3318,7 +3339,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -3331,7 +3353,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['A', 'B'] }"
       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'C'] }", kGPUDevice);
+      " input: ['A', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
 }
@@ -3347,7 +3370,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
       " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -3363,7 +3387,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
       " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -3384,7 +3409,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
       " attr { key: 'N' value { i: 2 } }"
       " input: ['A', 'B:0', 'B:1']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
             "B->D:1;B:1->D:2;C->E;D->E:1");
@@ -3406,7 +3432,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
       " attr { key: 'N' value { i: 2 } }"
       " input: ['B:0', 'B:1', 'A']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
             "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
@@ -3426,7 +3453,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
       " attr { key: 'is_training' value { b: true } }"
       " input: ['A', 'B', 'C', 'D', 'E'] }"
       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'F'] }", kGPUDevice);
+      " input: ['A', 'F'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);E(Input);"
             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
@@ -3455,7 +3483,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
       "node { name: 'Y' op: 'Input'}"
       "node { name: 'Z' op: 'Zeta'"
       " attr {key: 'T' value { type: DT_FLOAT } }"
-      " input: ['E', 'Y']}", kGPUDevice);
+      " input: ['E', 'Y']}",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
@@ -33,8 +33,8 @@ limitations under the License.
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
 
-#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
 
 namespace tensorflow {
 
@@ -152,12 +152,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
   string data_format;
 
   TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
-  bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) ==
-                         Status::OK();
+  bool dst_dtype_found =
+      GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK();
   // We compare source and destination datatypes only when both are found.
   if (dst_dtype_found && (src_datatype != dst_datatype)) {
-    string err_msg = "T attribute of " + src->name() + " and " +
-                     dst->name() + " do not match. Will not insert" +
+    string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
+                     " do not match. Will not insert" +
                      " MklToTf node in such case.";
     return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
   }
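The guard above only reports a mismatch when both edge endpoints actually carry a "T" attribute. A standalone sketch of that rule, using std::optional for the maybe-missing destination type; the names here are illustrative, not the TensorFlow types:

#include <iostream>
#include <optional>
#include <string>

using DataType = std::string;  // stand-in for tensorflow::DataType

// Only flag a datatype mismatch when the destination's "T" attr was found;
// returns an error message on mismatch, std::nullopt otherwise.
std::optional<std::string> CheckEdgeDatatypes(
    const std::string& src_name, const DataType& src_dtype,
    const std::string& dst_name, const std::optional<DataType>& dst_dtype) {
  if (dst_dtype.has_value() && *dst_dtype != src_dtype) {
    return "T attribute of " + src_name + " and " + dst_name +
           " do not match. Will not insert MklToTf node in such case.";
  }
  return std::nullopt;  // compatible, or dst carries no "T" attr
}

int main() {
  auto err = CheckEdgeDatatypes("conv", "DT_FLOAT", "cast", DataType("DT_HALF"));
  if (err) std::cout << *err << "\n";
}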
@@ -325,12 +325,12 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
       // may not be Mkl node.
       DataType src_datatype;
       DataType dst_datatype;
-      bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) ==
-                                Status::OK() &&
+      bool src_is_mkl_op =
+          (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() &&
            IsMklSupportedOp(src->type_string(), src_datatype));
-      bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) ==
-                                Status::OK() &&
+      bool dst_is_mkl_op =
+          (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() &&
            IsMklSupportedOp(dst->type_string(), dst_datatype));
 
       // Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
       if (src_is_mkl_op && !dst_is_mkl_op) {
@@ -40,7 +40,7 @@ REGISTER_KERNEL_BUILDER(
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(
     Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
 #endif  // TENSORFLOW_USE_SYCL
 
 // Register the HostConst Op
 // Returns a constant tensor on the host. Useful for writing C++ tests
@@ -114,7 +114,7 @@ class PeriodicFunction {
   void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_);
 
   const std::function<void()> function_;  // Actual client function
   const int64 interval_micros_;  // Interval between calls.
   const Options options_;
 
   // Protects state below.
@@ -55,15 +55,14 @@ Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
 // use the clock to be destroyed.
 std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
     test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
-  return std::unique_ptr<Thread>(
-      Env::Default()->StartThread({}, "FakeClockAdvancerThread",
-                                  [env, start, stop] {
-        start->WaitForNotification();
-        while (!stop->HasBeenNotified()) {
-          env->AdvanceByMicroseconds(10);
-          Env::Default()->SleepForMicroseconds(10);
-        }
-      }));
+  return std::unique_ptr<Thread>(Env::Default()->StartThread(
+      {}, "FakeClockAdvancerThread", [env, start, stop] {
+        start->WaitForNotification();
+        while (!stop->HasBeenNotified()) {
+          env->AdvanceByMicroseconds(10);
+          Env::Default()->SleepForMicroseconds(10);
+        }
+      }));
 }
 
 TEST(SharedBatchSchedulerTest, Basic) {
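The helper above runs a background thread that repeatedly advances a fake clock until notified to stop, so code waiting on simulated time can make progress. A standalone sketch of the same pattern with std::thread and atomics in place of TensorFlow's Env and FakeClockEnv:

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

int main() {
  std::atomic<long> fake_time_us{0};  // the "fake clock", a plain counter
  std::atomic<bool> stop{false};

  std::thread advancer([&] {
    while (!stop.load()) {
      fake_time_us.fetch_add(10);  // advance simulated time by 10us
      std::this_thread::sleep_for(std::chrono::microseconds(10));
    }
  });

  // Test body: wait until the simulated clock passes some deadline.
  while (fake_time_us.load() < 1000) {
    std::this_thread::yield();
  }
  std::cout << "deadline reached at fake t=" << fake_time_us.load() << "us\n";

  stop = true;
  advancer.join();  // always join before the clock goes out of scope
}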
@@ -258,7 +257,7 @@ TEST(SharedBatchSchedulerTest, ObeysTimeout) {
 TEST(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) {
   Notification first_batch_processed, second_batch_processed;
   auto callback = [&first_batch_processed, &second_batch_processed](
                       std::unique_ptr<Batch<FakeTask>> batch) {
     ASSERT_TRUE(batch->IsClosed());
     if (batch->size() == 1) {
       first_batch_processed.Notify();
@@ -301,7 +300,7 @@ TEST(SharedBatchSchedulerTest,
   {
     Notification first_batch_processed, second_batch_processed;
     auto callback = [&first_batch_processed, &second_batch_processed](
                         std::unique_ptr<Batch<FakeTask>> batch) {
       ASSERT_TRUE(batch->IsClosed());
       if (batch->size() == 1) {
         first_batch_processed.Notify();
@@ -349,7 +348,7 @@ TEST(SharedBatchSchedulerTest, Fairness) {
   auto queue_0_callback = [&queue_0_first_batch_scheduled,
                            &queue_0_first_batch_proceed,
                            &queue_0_second_batch_scheduled](
                               std::unique_ptr<Batch<FakeTask>> batch) {
     if (!queue_0_first_batch_scheduled.HasBeenNotified()) {
       queue_0_first_batch_scheduled.Notify();
       queue_0_first_batch_proceed.WaitForNotification();
@@ -467,7 +466,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) {
 TEST(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) {
   Notification queue_0_processing, queue_0_proceed;
   auto queue_0_callback = [&queue_0_processing, &queue_0_proceed](
                               std::unique_ptr<Batch<FakeTask>> batch) {
     if (!queue_0_processing.HasBeenNotified()) {
       queue_0_processing.Notify();
       queue_0_proceed.WaitForNotification();
@@ -92,7 +92,6 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
   }
 
  private:
-
   class Iterator : public DatasetIterator<Dataset> {
    public:
     explicit Iterator(const Params& params)
@@ -22,7 +22,6 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/notification.h"
 
-
 namespace tensorflow {
 
 /* static */
@@ -185,8 +184,7 @@ Status CapturedFunction::MaybeInstantiate(
   return Status::OK();
 }
 
-Status CapturedFunction::Run(IteratorContext* ctx,
-                             std::vector<Tensor>&& args,
+Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
                              std::vector<Tensor>* rets) {
   FunctionLibraryRuntime::Handle handle;
   TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
@@ -128,8 +128,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         while (i_ < dataset()->count_) {
           // Fetch and throw away Tensors.
           std::vector<Tensor> dummy_out_tensors;
-          TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &dummy_out_tensors,
-                                                  end_of_sequence));
+          TF_RETURN_IF_ERROR(
+              input_impl_->GetNext(ctx, &dummy_out_tensors, end_of_sequence));
           if (*end_of_sequence) {
             // We reached the end before the count was reached.
             input_impl_.reset();
@@ -140,8 +140,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         }
 
         // Return GetNext() on the underlying iterator.
-        TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
-                                                end_of_sequence));
+        TF_RETURN_IF_ERROR(
+            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
         if (*end_of_sequence) {
           input_impl_.reset();
         }
@@ -184,8 +184,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
 };
 };
 
-REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU),
-                        SkipDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp);
 
 }  // namespace
 
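SkipDataset's iterator, reformatted above, first pulls and discards `count` elements from its input and then forwards every later request unchanged, watching for an early end of sequence while skipping. A standalone sketch of that skip-then-forward pattern over a toy vector-backed source; the types are illustrative, not the TF iterator API:

#include <iostream>
#include <optional>
#include <vector>

struct Source {
  std::vector<int> data;
  size_t pos = 0;
  std::optional<int> GetNext() {
    if (pos >= data.size()) return std::nullopt;  // end of sequence
    return data[pos++];
  }
};

struct SkipIterator {
  Source* input;
  int count;        // how many elements to skip on first use
  int skipped = 0;

  std::optional<int> GetNext() {
    while (skipped < count) {
      // Fetch and throw away elements; stop if the input ends early.
      if (!input->GetNext().has_value()) return std::nullopt;
      ++skipped;
    }
    return input->GetNext();  // forward to the underlying iterator
  }
};

int main() {
  Source src{{1, 2, 3, 4, 5}};
  SkipIterator it{&src, 2};
  while (auto v = it.GetNext()) std::cout << *v << " ";  // prints: 3 4 5
  std::cout << "\n";
}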
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {
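The eleven identical hunks above, one per fuzzer source, are clang-format's include sorting at work: with SortIncludes enabled (as in Google style), directives within one uninterrupted #include block are ordered lexicographically, which moves tensorflow/cc/... ahead of tensorflow/core/.... A tiny compilable illustration; clang-format would reorder the block below to algorithm, iostream, string:

// One uninterrupted include block; a blank line would start a new,
// separately sorted block.
#include <string>
#include <iostream>
#include <algorithm>

int main() {
  std::string s = "include order";
  std::sort(s.begin(), s.end());  // unrelated work, just to keep this runnable
  std::cout << s << "\n";
}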
@@ -46,7 +46,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
       GetTopNFloatResults(data, labels, element_count);
   LOG(INFO) << "=== Dump ranking ===";
   for (int i = 0; i < top_n; ++i) {
-    const std::tuple<float, int, string> &entry = queue.top();
+    const std::tuple<float, int, string>& entry = queue.top();
     LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
               << ", " << std::get<0>(entry);
     queue.pop();
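The ranking dump above pops the top N entries from a priority queue of (score, index, label) tuples, relying on std::tuple's lexicographic ordering so the highest score surfaces first. A standalone sketch with made-up scores; the TF version fills the queue from a float* of model outputs:

#include <iostream>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

int main() {
  using Entry = std::tuple<float, int, std::string>;
  std::priority_queue<Entry> queue;  // max-heap; compares score first

  std::vector<std::pair<float, std::string>> scores = {
      {0.1f, "cat"}, {0.7f, "dog"}, {0.2f, "fish"}};
  for (int i = 0; i < static_cast<int>(scores.size()); ++i) {
    queue.emplace(scores[i].first, i, scores[i].second);
  }

  const int top_n = 2;
  for (int i = 0; i < top_n && !queue.empty(); ++i) {
    const Entry& entry = queue.top();
    std::cout << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
              << ", " << std::get<0>(entry) << "\n";  // index, label, score
    queue.pop();
  }
}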
@@ -181,8 +181,8 @@ class GraphTransferer {
   void AppendNodeInputParams(const int id, const Node& node,
                              const std::vector<int>& extra_inputs);
 
-  void AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
-                              const int id, const Node& node);
+  void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id,
+                              const Node& node);
 
   static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
       const shape_inference::ShapeHandle& shape_handle,
@@ -42,8 +42,7 @@ constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
 
 class GraphTransfererTest : public ::testing::Test {
  protected:
-  void SetUp() final {
-  }
+  void SetUp() final {}
 
   GraphTransferer gt_;
 };
@@ -61,7 +60,7 @@ class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
       }
     }
     return -1;
   }
 
  private:
   const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
@@ -420,7 +420,7 @@ TEST(GraphTransferer,
       false,  // is_text_proto
       false,  // shape_inference_for_unknown_shape
       true    // dry_run_for_unknown_shape
   );
   ASSERT_TRUE(status.ok()) << status;
   prof.Stop();
   prof.DumpStatistics("LoadGraphFromProtoFile");
@@ -487,7 +487,7 @@ TEST(GraphTransferer,
       false,  // is_text_proto
       true,   // shape_inference_for_unknown_shape
       false   // dry_run_for_unknown_shape
   );
   ASSERT_TRUE(status.ok()) << status;
   prof.Stop();
   prof.DumpStatistics("LoadGraphFromProtoFile");
@@ -556,7 +556,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
       false,  // is_text_proto
       false,  // shape_inference_for_unknown_shape
       true    // dry_run_for_unknown_shape
   );
   const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo();
 
   ASSERT_TRUE(status.ok());
@@ -576,7 +576,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
       false,  // is_text_proto
       true,   // shape_inference_for_unknown_shape
       false   // dry_run_for_unknown_shape
   );
   const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo();
 
   ASSERT_TRUE(status.ok());
@@ -71,10 +71,10 @@ class NeonDepthwiseConv2dNativeOp : public BinaryOp<float> {
                 filter.shape().DebugString()));
 
     const int32 in_depth = input.dim_size(3);
-    OP_REQUIRES(
-        context, in_depth == filter.dim_size(2),
-        errors::InvalidArgument("input and filter must have the same depth: ",
-                                in_depth, " vs ", filter.dim_size(2)));
+    OP_REQUIRES(context, in_depth == filter.dim_size(2),
+                errors::InvalidArgument(
+                    "input and filter must have the same depth: ", in_depth,
+                    " vs ", filter.dim_size(2)));
     const int32 batch = input.dim_size(0);
     const int32 input_rows = input.dim_size(1);
     const int32 input_cols = input.dim_size(2);
@@ -131,7 +131,7 @@ inline tensorflow::string* TfCheckOpHelper(::tensorflow::Status v,
   while (auto _result = ::tensorflow::TfCheckOpHelper(val, #val)) \
   LOG(level) << *(_result)
 
 #define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL)
 #define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL)
 
 // DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt
@@ -66,7 +66,9 @@ struct EigenEnvironment {
     }
     return Task{
         std::unique_ptr<TaskImpl>(new TaskImpl{
-            std::move(f), Context(ContextKind::kThread), id,
+            std::move(f),
+            Context(ContextKind::kThread),
+            id,
         }),
     };
   }
@@ -97,8 +97,8 @@ TEST(ThreadPool, ParallelForWithWorkerId) {
   }
   pool.ParallelForWithWorkerId(
       kWorkItems, kHugeCost,
-      [&threads_running, &work, num_threads](
-          int64 begin, int64 end, int64 id) {
+      [&threads_running, &work, num_threads](int64 begin, int64 end,
+                                             int64 id) {
         // Store true for the current thread, and assert that another thread
        // is not running with the same id.
        ASSERT_LE(0, id);
@@ -18,12 +18,12 @@ limitations under the License.
 #include <mutex>
 
 #include "sqlite3.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/lib/core/refcount.h"
 
 /// TensorFlow SQLite Veneer
 ///
@@ -121,10 +121,7 @@ class LOCKABLE Sqlite : public core::RefCounted {
 
   Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit,
          sqlite3_stmt* rollback) noexcept
-      : db_(db),
-        begin_(begin),
-        commit_(commit),
-        rollback_(rollback) {}
+      : db_(db), begin_(begin), commit_(commit), rollback_(rollback) {}
 
   sqlite3* const db_;
   sqlite3_stmt* const begin_;
@@ -233,7 +230,8 @@ class SqliteStatement {
   /// freed until this statement is Reset() or finalized.
   void BindText(int parameter, const StringPiece& text) {
     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
-                               SQLITE_TRANSIENT, SQLITE_UTF8), parameter);
+                               SQLITE_TRANSIENT, SQLITE_UTF8),
+           parameter);
     size_ += text.size();
   }
   void BindText(const char* parameter, const StringPiece& text) {
@@ -241,7 +239,8 @@ class SqliteStatement {
   }
   void BindTextUnsafe(int parameter, const StringPiece& text) {
     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
-                               SQLITE_STATIC, SQLITE_UTF8), parameter);
+                               SQLITE_STATIC, SQLITE_UTF8),
+           parameter);
     size_ += text.size();
   }
   void BindTextUnsafe(const char* parameter, const StringPiece& text) {
@@ -254,7 +253,8 @@ class SqliteStatement {
   /// freed until this statement is Reset() or finalized.
   void BindBlob(int parameter, const StringPiece& blob) {
     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
-                               SQLITE_TRANSIENT), parameter);
+                               SQLITE_TRANSIENT),
+           parameter);
     size_ += blob.size();
   }
   void BindBlob(const char* parameter, const StringPiece& blob) {
@@ -262,7 +262,8 @@ class SqliteStatement {
   }
   void BindBlobUnsafe(int parameter, const StringPiece& blob) {
     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
-                               SQLITE_STATIC), parameter);
+                               SQLITE_STATIC),
+           parameter);
     size_ += blob.size();
   }
   void BindBlobUnsafe(const char* parameter, const StringPiece& text) {
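The Bind* wrappers reformatted above sit on SQLite's raw bind API, where SQLITE_TRANSIENT asks SQLite to copy the buffer and SQLITE_STATIC promises the caller keeps it alive until the statement is reset or finalized. A minimal sketch against the plain sqlite3 C API (requires linking sqlite3; error handling abbreviated):

#include <cstdio>
#include <string>
#include "sqlite3.h"

int main() {
  sqlite3* db = nullptr;
  sqlite3_open(":memory:", &db);
  sqlite3_exec(db, "CREATE TABLE T (a TEXT)", nullptr, nullptr, nullptr);

  sqlite3_stmt* stmt = nullptr;
  sqlite3_prepare_v2(db, "INSERT INTO T (a) VALUES (?)", -1, &stmt, nullptr);

  std::string text = "hello";
  // Parameter indexes are 1-based; SQLITE_TRANSIENT copies the buffer,
  // which is the safe choice since `text` is a local.
  sqlite3_bind_text(stmt, 1, text.data(), static_cast<int>(text.size()),
                    SQLITE_TRANSIENT);
  if (sqlite3_step(stmt) != SQLITE_DONE) {
    std::printf("step failed: %s\n", sqlite3_errmsg(db));
  }

  sqlite3_finalize(stmt);
  sqlite3_close(db);
}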
@@ -320,9 +321,7 @@ class SqliteStatement {
 
   /// \brief Move constructor, after which <other> is reset to empty.
   SqliteStatement(SqliteStatement&& other) noexcept
-      : db_(other.db_),
-        stmt_(other.stmt_),
-        bind_error_(other.bind_error_) {
+      : db_(other.db_), stmt_(other.stmt_), bind_error_(other.bind_error_) {
     other.db_ = nullptr;
     other.stmt_ = nullptr;
     other.bind_error_ = SQLITE_OK;
@@ -33,9 +33,7 @@ class SqliteTest : public ::testing::Test {
     db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie();
   }
 
-  void TearDown() override {
-    db_->Unref();
-  }
+  void TearDown() override { db_->Unref(); }
 
   Sqlite* db_;
   bool is_done_;
@@ -213,7 +211,7 @@ TEST_F(SqliteTest, BindFailed) {
   Status s = stmt.StepOnce();
   EXPECT_NE(string::npos,
             s.error_message().find("INSERT INTO T (a) VALUES (123)"))
       << s.error_message();
 }
 
 TEST_F(SqliteTest, SnappyExtension) {
@@ -226,7 +224,7 @@ TEST_F(SqliteTest, SnappyBinaryCompatibility) {
   EXPECT_EQ(
       "today is the end of the republic",
       db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64"
                         "206F66207468652072657075626C6963')")
           .StepOnceOrDie()
           .ColumnString(0));
 }
@@ -55,22 +55,21 @@ namespace gtl {
 template <typename F>
 class Cleanup {
  public:
-  Cleanup()
-      : released_(true), f_() {}
+  Cleanup() : released_(true), f_() {}

   template <typename G>
   explicit Cleanup(G&& f)          // NOLINT
       : f_(std::forward<G>(f)) {}  // NOLINT(build/c++11)

   Cleanup(Cleanup&& src)  // NOLINT
-      : released_(src.is_released()), f_(src.release()) { }
+      : released_(src.is_released()), f_(src.release()) {}

   // Implicitly move-constructible from any compatible Cleanup<G>.
   // The source will be released as if src.release() were called.
   // A moved-from Cleanup can be safely destroyed or reassigned.
   template <typename G>
   Cleanup(Cleanup<G>&& src)  // NOLINT
-      : released_(src.is_released()), f_(src.release()) { }
+      : released_(src.is_released()), f_(src.release()) {}

   // Assignment to a Cleanup object behaves like destroying it
   // and making a new one in its place, analogous to unique_ptr
@@ -102,8 +101,8 @@ class Cleanup {
   F f_;
 };

-template <int&... ExplicitParameterBarrier,
-          typename F, typename DecayF = typename std::decay<F>::type>
+template <int&... ExplicitParameterBarrier, typename F,
+          typename DecayF = typename std::decay<F>::type>
 TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
   return Cleanup<DecayF>(std::forward<F>(f));
 }
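For orientation, MakeCleanup (whose template parameter list is rewrapped above) is the factory for this scope guard: the stored callable runs when the Cleanup object leaves scope unless release() disarms it first, and TF_MUST_USE_RESULT stops callers from discarding the returned guard, which would run the cleanup immediately. A minimal usage sketch, assuming only this header:

#include "tensorflow/core/lib/gtl/cleanup.h"

#include <cstdio>

void AppendLine(const char* path, const char* line) {
  FILE* f = fopen(path, "a");
  if (f == nullptr) return;
  // The lambda runs on every exit path out of this scope...
  auto closer = tensorflow::gtl::MakeCleanup([f] { fclose(f); });
  if (fputs(line, f) == EOF) return;  // ...including this early return.
  fputs("\n", f);
  // Calling closer.release() here would disarm the guard instead.
}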
@@ -65,15 +65,14 @@ TEST(CleanupTest, Release) {
 TEST(FinallyTest, TypeErasedWithoutFactory) {
   string s = "active";
   {
-    AnyCleanup s_cleaner([&s]{ s.append(" clean"); });
+    AnyCleanup s_cleaner([&s] { s.append(" clean"); });
     EXPECT_EQ("active", s);
   }
   EXPECT_EQ("active clean", s);
 }

 struct Appender {
-  Appender(string* s, const string& msg)
-      : s_(s), msg_(msg) {}
+  Appender(string* s, const string& msg) : s_(s), msg_(msg) {}
   void operator()() const { s_->append(msg_); }
   string* s_;
   string msg_;
@@ -163,7 +162,12 @@ class CleanupReferenceTest : public ::testing::Test {
   int* i;
   F(int* cp, int* i) : cp(cp), i(i) {}
   F(const F& o) : cp(o.cp), i(o.i) { ++*cp; }
-  F& operator=(const F& o) { cp = o.cp; i = o.i; ++*cp; return *this; }
+  F& operator=(const F& o) {
+    cp = o.cp;
+    i = o.i;
+    ++*cp;
+    return *this;
+  }
   F(F&&) = default;
   F& operator=(F&&) = default;
   void operator()() const { ++*i; }
@@ -279,7 +283,7 @@ BENCHMARK(BM_AnyCleanup);

 void BM_AnyCleanupNoFactory(int iters) {
   while (iters--) {
-    AnyCleanup fin([]{Incr();});
+    AnyCleanup fin([] { Incr(); });
   }
 }
 BENCHMARK(BM_AnyCleanupNoFactory);
@@ -31,12 +31,12 @@ limitations under the License.
 #ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
 #define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_

-#include <cstddef>
 #include <stddef.h>
 #include <stdlib.h>
 #include <string.h>
 #include <sys/types.h>
 #include <algorithm>
+#include <cstddef>
 #include <iterator>
 #include <memory>
 #include <type_traits>
@@ -407,7 +407,7 @@ class InlinedVector {
   };
   // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
   struct Construct {
-    template<class... Args>
+    template <class... Args>
     void operator()(T* dst, Args&&... args) const {
       new (dst) T(std::forward<Args>(args)...);
     }
@@ -255,13 +255,13 @@ class IntType {
     value_ op arg_value;                             \
     return *this;                                    \
   }
-  INT_TYPE_ASSIGNMENT_OP(+= );
-  INT_TYPE_ASSIGNMENT_OP(-= );
-  INT_TYPE_ASSIGNMENT_OP(*= );
-  INT_TYPE_ASSIGNMENT_OP(/= );
-  INT_TYPE_ASSIGNMENT_OP(<<= );  // NOLINT
-  INT_TYPE_ASSIGNMENT_OP(>>= );  // NOLINT
-  INT_TYPE_ASSIGNMENT_OP(%= );
+  INT_TYPE_ASSIGNMENT_OP(+=);
+  INT_TYPE_ASSIGNMENT_OP(-=);
+  INT_TYPE_ASSIGNMENT_OP(*=);
+  INT_TYPE_ASSIGNMENT_OP(/=);
+  INT_TYPE_ASSIGNMENT_OP(<<=);  // NOLINT
+  INT_TYPE_ASSIGNMENT_OP(>>=);  // NOLINT
+  INT_TYPE_ASSIGNMENT_OP(%=);
 #undef INT_TYPE_ASSIGNMENT_OP

   ThisType& operator=(ValueType arg_value) {
@@ -314,10 +314,10 @@ std::ostream& operator<<(std::ostream& os,  // NOLINT
 INT_TYPE_ARITHMETIC_OP(+);
 INT_TYPE_ARITHMETIC_OP(-);
 INT_TYPE_ARITHMETIC_OP(*);
-INT_TYPE_ARITHMETIC_OP(/ );
-INT_TYPE_ARITHMETIC_OP(<< );  // NOLINT
-INT_TYPE_ARITHMETIC_OP(>> );  // NOLINT
-INT_TYPE_ARITHMETIC_OP(% );
+INT_TYPE_ARITHMETIC_OP(/);
+INT_TYPE_ARITHMETIC_OP(<<);  // NOLINT
+INT_TYPE_ARITHMETIC_OP(>>);  // NOLINT
+INT_TYPE_ARITHMETIC_OP(%);
 #undef INT_TYPE_ARITHMETIC_OP

 // -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------
@@ -345,12 +345,12 @@ INT_TYPE_ARITHMETIC_OP(% );
                                  IntType<IntTypeName, ValueType> id) {      \
     return val op id.value();                                               \
   }
-INT_TYPE_COMPARISON_OP(== );  // NOLINT
-INT_TYPE_COMPARISON_OP(!= );  // NOLINT
-INT_TYPE_COMPARISON_OP(< );   // NOLINT
-INT_TYPE_COMPARISON_OP(<= );  // NOLINT
-INT_TYPE_COMPARISON_OP(> );   // NOLINT
-INT_TYPE_COMPARISON_OP(>= );  // NOLINT
+INT_TYPE_COMPARISON_OP(==);  // NOLINT
+INT_TYPE_COMPARISON_OP(!=);  // NOLINT
+INT_TYPE_COMPARISON_OP(<);   // NOLINT
+INT_TYPE_COMPARISON_OP(<=);  // NOLINT
+INT_TYPE_COMPARISON_OP(>);   // NOLINT
+INT_TYPE_COMPARISON_OP(>=);  // NOLINT
 #undef INT_TYPE_COMPARISON_OP

 }  // namespace gtl
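The macros being tightened above stamp out the full operator surface of gtl::IntType, a strongly typed integer wrapper that refuses to mix with other integer-like types. A small sketch of the intent, assuming the header's TF_LIB_GTL_DEFINE_INT_TYPE convenience macro; GpuId and TpuId are example names, not TensorFlow types:

#include "tensorflow/core/lib/gtl/int_type.h"

TF_LIB_GTL_DEFINE_INT_TYPE(GpuId, int);
TF_LIB_GTL_DEFINE_INT_TYPE(TpuId, int);

void Demo() {
  GpuId gpu(3);
  gpu += 1;               // INT_TYPE_ASSIGNMENT_OP(+=) from the diff
  GpuId twice = gpu * 2;  // INT_TYPE_ARITHMETIC_OP(*) against the raw type
  bool same = (twice == GpuId(8));  // INT_TYPE_COMPARISON_OP(==)
  // TpuId tpu = gpu;     // would not compile: distinct IntTypes don't mix
  int raw = gpu.value();  // the explicit escape hatch back to the raw value
  (void)same;
  (void)raw;
}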
@@ -42,7 +42,8 @@ class IntTypeTest : public ::testing::Test {

 // All tests below will be executed on all supported IntTypes.
 typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT,
-                         Int64_IT, UInt64_IT, Long_IT> SupportedIntTypes;
+                         Int64_IT, UInt64_IT, Long_IT>
+    SupportedIntTypes;

 TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes);

@@ -232,7 +233,8 @@ TYPED_TEST(IntTypeTest, TestOperators) {

 TYPED_TEST(IntTypeTest, TestHashFunctor) {
   std::unordered_map<typename TestFixture::T, char,
-                     typename TestFixture::T::Hasher> map;
+                     typename TestFixture::T::Hasher>
+      map;
   typename TestFixture::T a(0);
   map[a] = 'c';
   EXPECT_EQ('c', map[a]);
@@ -593,12 +593,12 @@ class optional : private internal_optional::optional_data<T>,
     assert(this->engaged_);
     return this->pointer();
   }
-  constexpr const T& operator*() const & { return reference(); }
+  constexpr const T& operator*() const& { return reference(); }
   T& operator*() & {
     assert(this->engaged_);
     return reference();
   }
-  constexpr const T&& operator*() const && { return std::move(reference()); }
+  constexpr const T&& operator*() const&& { return std::move(reference()); }
   T&& operator*() && {
     assert(this->engaged_);
     return std::move(reference());
@@ -621,7 +621,7 @@ class optional : private internal_optional::optional_data<T>,
   // Use `opt.value()` to get a reference to underlying value. The constness
   // and lvalue/rvalue-ness of `opt` is preserved to the view of the T
   // subobject.
-  const T& value() const & {
+  const T& value() const& {
     CHECK(*this) << "Bad optional access";
     return reference();
   }
@@ -633,7 +633,7 @@ class optional : private internal_optional::optional_data<T>,
     CHECK(*this) << "Bad optional access";
     return std::move(reference());
   }
-  const T&& value() const && {  // NOLINT(build/c++11)
+  const T&& value() const&& {  // NOLINT(build/c++11)
     CHECK(*this) << "Bad optional access";
     return std::move(reference());
   }
@@ -641,7 +641,7 @@ class optional : private internal_optional::optional_data<T>,
   // Use `opt.value_or(val)` to get either the value of T or the given default
   // `val` in the empty case.
   template <class U>
-  constexpr T value_or(U&& v) const & {
+  constexpr T value_or(U&& v) const& {
     return static_cast<bool>(*this) ? **this
                                     : static_cast<T>(std::forward<U>(v));
   }
@@ -656,8 +656,8 @@ class optional : private internal_optional::optional_data<T>,
   constexpr const T& reference() const { return *this->pointer(); }
   T& reference() { return *(this->pointer()); }

-  // T constraint checks. You can't have an optional of nullopt_t, in_place_t or
-  // a reference.
+  // T constraint checks. You can't have an optional of nullopt_t, in_place_t
+  // or a reference.
   static_assert(
       !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
       "optional<nullopt_t> is not allowed.");
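The const&/&& tokens that clang-format re-spaces above are reference qualifiers on the member functions: they let gtl::optional hand out a copy, a const view, or a moved-out value depending on the value category of the optional itself. A short usage sketch, assuming this header; MaybeName is an example, not TensorFlow API:

#include "tensorflow/core/lib/gtl/optional.h"

#include <string>
#include <utility>

using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;

optional<std::string> MaybeName(bool have) {
  if (!have) return nullopt;
  return std::string("a name long enough to make copying visible");
}

void Demo() {
  optional<std::string> n = MaybeName(true);
  std::string copy = n.value();              // lvalue overload: copies
  std::string taken = std::move(n).value();  // && overload: moves the payload
  std::string fallback = MaybeName(false).value_or("anonymous");
  (void)copy;
  (void)taken;
  (void)fallback;
}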
@@ -24,17 +24,29 @@ limitations under the License.
 namespace tensorflow {
 namespace {

-using tensorflow::gtl::optional;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::nullopt_t;
 using tensorflow::gtl::in_place;
 using tensorflow::gtl::in_place_t;
 using tensorflow::gtl::make_optional;
+using tensorflow::gtl::nullopt;
+using tensorflow::gtl::nullopt_t;
+using tensorflow::gtl::optional;

-template <typename T> string TypeQuals(T&) { return "&"; }
-template <typename T> string TypeQuals(T&&) { return "&&"; }
-template <typename T> string TypeQuals(const T&) { return "c&"; }
-template <typename T> string TypeQuals(const T&&) { return "c&&"; }
+template <typename T>
+string TypeQuals(T&) {
+  return "&";
+}
+template <typename T>
+string TypeQuals(T&&) {
+  return "&&";
+}
+template <typename T>
+string TypeQuals(const T&) {
+  return "c&";
+}
+template <typename T>
+string TypeQuals(const T&&) {
+  return "c&&";
+}

 struct StructorListener {
   int construct0 = 0;
@@ -28,10 +28,10 @@ limitations under the License.

 namespace {

+using tensorflow::string;
 using tensorflow::gtl::TopN;
 using tensorflow::random::PhiloxRandom;
 using tensorflow::random::SimplePhilox;
-using tensorflow::string;

 // Move the contents from an owned raw pointer, returning by value.
 // Objects are easier to manage by value.
@@ -22,6 +22,6 @@ namespace compression {
 const char kNone[] = "";
 const char kGzip[] = "GZIP";

-}
-}
-}
+}  // namespace compression
+}  // namespace io
+}  // namespace tensorflow
@@ -23,8 +23,8 @@ namespace compression {
 extern const char kNone[];
 extern const char kGzip[];

-}
-}
-}
+}  // namespace compression
+}  // namespace io
+}  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
@@ -207,7 +207,7 @@ Status RecordReader::SkipNBytes(uint64 offset) {
     }
   }
   return Status::OK();
-}
+}  // namespace io

 SequentialRecordReader::SequentialRecordReader(
     RandomAccessFile* file, const RecordReaderOptions& options)
@@ -218,8 +218,8 @@ TEST_F(RecordioTest, RandomRead) {

 // Tests of all the error paths in log_reader.cc follow:
 static void AssertHasSubstr(StringPiece s, StringPiece expected) {
-  EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain "
-                                                 << expected;
+  EXPECT_TRUE(StringPiece(s).contains(expected))
+      << s << " does not contain " << expected;
 }

 TEST_F(RecordioTest, ReadError) {
@@ -197,8 +197,8 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels,
                       int desired_channel_bits, DecodeContext* context) {
   CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
       << "desired_channel_bits = " << desired_channel_bits;
-  CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = "
-                                                        << desired_channels;
+  CHECK(0 <= desired_channels && desired_channels <= 4)
+      << "desired_channels = " << desired_channels;
   context->error_condition = false;
   context->channels = desired_channels;
   context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
@@ -35,8 +35,8 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p,
                  int64 size) {
   const int granularity = Distribution::kResultElementCount;

-  CHECK(size % granularity == 0) << " size: " << size
-                                 << " granularity: " << granularity;
+  CHECK(size % granularity == 0)
+      << " size: " << size << " granularity: " << granularity;

   Distribution dist;
   for (int i = 0; i < size; i += granularity) {
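The contract in the reflowed CHECK is that Philox-style distributions emit results in fixed batches of Distribution::kResultElementCount, so the requested size must be a multiple of that granularity. A generic sketch of the same batch-fill pattern in portable C++, with a hypothetical UniformBatch standing in for the TensorFlow distributions:

#include <array>
#include <cassert>
#include <cstdint>
#include <random>

// Hypothetical stand-in for a TensorFlow Distribution: results come out in
// fixed-size batches, mirroring Distribution::kResultElementCount.
struct UniformBatch {
  static constexpr int kResultElementCount = 4;
  template <typename Gen>
  std::array<float, kResultElementCount> operator()(Gen& gen) const {
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    std::array<float, kResultElementCount> out;
    for (float& v : out) v = u(gen);
    return out;
  }
};

// Same shape as FillRandoms above: the output size must be a multiple of the
// batch granularity, which is exactly what the CHECK in the diff enforces.
template <typename Distribution, typename Gen>
void FillBatched(Gen gen, float* p, int64_t size) {
  const int granularity = Distribution::kResultElementCount;
  assert(size % granularity == 0);
  Distribution dist;
  for (int64_t i = 0; i < size; i += granularity) {
    const auto batch = dist(gen);
    for (int j = 0; j < granularity; ++j) p[i + j] = batch[j];
  }
}

// Usage: float out[8]; FillBatched<UniformBatch>(std::mt19937{42}, out, 8);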
@@ -17,8 +17,8 @@ limitations under the License.
 #define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_

 #define _USE_MATH_DEFINES
-#include <cmath>
 #include <math.h>
+#include <cmath>
 #undef _USE_MATH_DEFINES

 #include <string.h>
@@ -27,7 +27,6 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/lib/random/philox_random.h"

-
 namespace tensorflow {
 namespace random {

@@ -45,8 +45,8 @@ void FillRandomsWithSingles(PhiloxRandom gen,
                             int64 size) {
   int granularity = Distribution::kResultElementCount;

-  CHECK(size % granularity == 0) << " size: " << size
-                                 << " granularity: " << granularity;
+  CHECK(size % granularity == 0)
+      << " size: " << size << " granularity: " << granularity;

   SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

@@ -472,7 +472,8 @@ void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) {
   // buf = val in network byte order, sign extended to 10 bytes
   const char sign_byte = val < 0 ? '\xff' : '\0';
   char buf[10] = {
-      sign_byte, sign_byte,
+      sign_byte,
+      sign_byte,
   };
   StoreBigEndian64(buf + 2, val);
   static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch");
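The buffer above implements the comment directly: for a signed 64-bit value the encoder lays out up to 10 big-endian bytes, front-filling with 0xff or 0x00 sign bytes so that byte-wise comparison of encodings tracks numeric order. A hedged sketch of the same idea; StoreBE64 is my stand-in for the StoreBigEndian64 helper in the diff, and the real encoder additionally trims redundant leading bytes and tags the length, which this sketch omits:

#include <cstdint>
#include <string>

// Stand-in for the StoreBigEndian64 used in the diff: most significant
// byte first, so lexicographic order matches numeric order.
static void StoreBE64(char* dst, uint64_t v) {
  for (int i = 7; i >= 0; --i) {
    dst[i] = static_cast<char>(v & 0xff);
    v >>= 8;
  }
}

// Sign-extension trick: 2 sign bytes + 8 value bytes = 10 bytes, so any
// int64 fits and negative values still sort below non-negative ones.
std::string SignExtended10(int64_t val) {
  const char sign_byte = val < 0 ? '\xff' : '\0';
  char buf[10] = {sign_byte, sign_byte};
  StoreBE64(buf + 2, static_cast<uint64_t>(val));
  return std::string(buf, sizeof(buf));
}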
@@ -126,7 +126,7 @@ class AlphaNum {
       : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {}

   AlphaNum(const Eigen::half &f);  // NOLINT(runtime/explicit)
   AlphaNum(Hex hex);               // NOLINT(runtime/explicit)

   AlphaNum(const char *c_str) : piece_(c_str) {}   // NOLINT(runtime/explicit)
   AlphaNum(const StringPiece &pc) : piece_(pc) {}  // NOLINT(runtime/explicit)
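AlphaNum, whose comment alignment is touched above, is the implicit-conversion funnel that lets strings::StrCat accept numbers, Hex values, and string-ish arguments in one variadic call with no format string. A tiny usage sketch; DeviceLabel is an example name, not TensorFlow API:

#include "tensorflow/core/lib/strings/strcat.h"

#include <string>

// Each argument is first converted to an AlphaNum, which formats numbers
// into a small on-stack buffer; StrCat then concatenates the pieces.
std::string DeviceLabel(int id, double utilization) {
  return tensorflow::strings::StrCat("/device:GPU:", id,
                                     " util=", utilization);
}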
@@ -25,8 +25,9 @@ namespace tensorflow {
 namespace {

 TEST(BackwardsCompatibilityTest, IsCompatible) {
-  OpCompatibilityLib compatibility(
-      "tensorflow/core/ops", strings::StrCat("v", TF_MAJOR_VERSION), nullptr);
+  OpCompatibilityLib compatibility("tensorflow/core/ops",
+                                   strings::StrCat("v", TF_MAJOR_VERSION),
+                                   nullptr);

   Env* env = Env::Default();
   int changed_ops = 0;
@@ -18,9 +18,9 @@ limitations under the License.
 #include <arpa/inet.h>
 #include <netdb.h>
 #else
+#include <Windows.h>
 #include <winsock2.h>
 #include <ws2tcpip.h>
-#include <Windows.h>
 #endif
 #include <sys/types.h>

@@ -38,8 +38,7 @@ class FakeHttpRequest : public CurlHttpRequest {
  public:
   /// Return the response for the given request.
   FakeHttpRequest(const string& request, const string& response)
-      : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {
-  }
+      : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {}

   /// Return the response with headers for the given request.
   FakeHttpRequest(const string& request, const string& response,
@@ -160,12 +160,12 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
   ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key));
   ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(),
                                       header_dot_claim.size()));
-  ASSERT_EQ(
-      1,
-      EVP_DigestVerifyFinal(
-          md_ctx, const_cast<unsigned char*>(
-                      reinterpret_cast<const unsigned char*>(signature.data())),
-          signature.size()));
+  ASSERT_EQ(1,
+            EVP_DigestVerifyFinal(
+                md_ctx,
+                const_cast<unsigned char*>(
+                    reinterpret_cast<const unsigned char*>(signature.data())),
+                signature.size()));
   EVP_MD_CTX_cleanup(md_ctx);

   // Free all the crypto-related resources.
@@ -25,7 +25,6 @@ namespace tensorflow {

 namespace {

-
 class RetryingRandomAccessFile : public RandomAccessFile {
  public:
   RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file,
@@ -27,8 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) {
   VLOG(2) << "Libdevice root = " << LibdeviceRoot();
   std::vector<string> libdevice_files;
   TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
-      io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"),
-      &libdevice_files));
+      io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), &libdevice_files));
   EXPECT_LT(0, libdevice_files.size());
 }
 #endif
@@ -579,8 +579,10 @@ Status DeviceTracerImpl::Collect(StepStatsCollector *collector) {
   // TODO(pbar) Handle device IDs and prefix properly.
   const string prefix = "";
   const int id = 0;
-  const string stream_device = strings::StrCat(prefix, "/device:GPU:", id, "/stream:");
-  const string memcpy_device = strings::StrCat(prefix, "/device:GPU:", id, "/memcpy");
+  const string stream_device =
+      strings::StrCat(prefix, "/device:GPU:", id, "/stream:");
+  const string memcpy_device =
+      strings::StrCat(prefix, "/device:GPU:", id, "/memcpy");

   mutex_lock l2(trace_mu_);
   for (const auto &rec : kernel_records_) {
@@ -83,15 +83,14 @@ void LogMessage::GenerateLogMessage() {
   const size_t time_buffer_size = 30;
   char time_buffer[time_buffer_size];
   strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S",
            localtime(&now_seconds));

   // TODO(jeff,sanjay): Replace this with something that logs through the env.
   fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder,
           "IWEF"[severity_], fname_, line_, str().c_str());
 }
 #endif

-
 namespace {

 // Parse log level (int64) from environment variable (char*)
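The trailing comment refers to the helper that reads a log level such as TF_CPP_MIN_LOG_LEVEL out of the environment. A minimal sketch of that kind of parsing in portable C++; LogLevelFromEnv is my name for it, not TensorFlow's:

#include <cstdint>
#include <cstdlib>

// Parse an int64 log level from an environment variable, falling back to a
// default when the variable is unset, empty, or not a clean integer.
int64_t LogLevelFromEnv(const char* var, int64_t default_level) {
  const char* val = std::getenv(var);
  if (val == nullptr || *val == '\0') return default_level;
  char* end = nullptr;
  const long long parsed = std::strtoll(val, &end, 10);
  if (end == nullptr || *end != '\0') return default_level;  // trailing junk
  return static_cast<int64_t>(parsed);
}

// Usage: int64_t min_level = LogLevelFromEnv("TF_CPP_MIN_LOG_LEVEL", 0);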