Add a grappler optimization (CPU) for pow(x, 3) -> mul(x, square(x)).

Benchmarks show mul+square to be about 13x to 50x faster than pow on CPUs:

name                                    time/op
BM_cpu_Cube_CubeWithPow3/4k              362µs ± 0%
BM_cpu_Cube_CubeWithPow3/64k             782µs ± 0%
BM_cpu_Cube_CubeWithPow3/1M             8.59ms ± 0%
BM_cpu_Cube_CubeWithTwoMuls/4k          11.4µs ± 3%
BM_cpu_Cube_CubeWithTwoMuls/64k         61.5µs ±12%
BM_cpu_Cube_CubeWithTwoMuls/1M           172µs ± 1%
BM_cpu_Cube_CubeWithMulSquare/4k        13.7µs ± 3%
BM_cpu_Cube_CubeWithMulSquare/64k       57.5µs ± 2%
BM_cpu_Cube_CubeWithMulSquare/1M         173µs ± 1%

PiperOrigin-RevId: 257149180
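Why the gap is so large: a generic pow(x, y) must handle arbitrary real exponents, which in typical libm implementations goes through an exp/log-style path per element, while mul(x, square(x)) is just two vectorizable multiplies. A minimal standalone C++ sketch of the identity the rewrite exploits (illustration only, not part of this commit; plain libm, no TensorFlow):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  for (float x : {-4.0f, -0.5f, 0.0f, 1.0f, 2.0f, 3.25f}) {
    const float via_pow = std::pow(x, 3.0f);  // generic power routine
    const float via_mul = x * (x * x);        // mul(x, square(x)): two multiplies
    // The two agree up to float rounding, which is what makes the
    // graph rewrite safe for real-valued inputs.
    assert(std::fabs(via_pow - via_mul) <=
           1e-5f * std::max(1.0f, std::fabs(via_pow)));
    std::printf("x=%6.2f  pow=%10.4f  x*x^2=%10.4f\n", x, via_pow, via_mul);
  }
  return 0;
}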
Authored by A. Unique TensorFlower on 2019-07-09 02:07:26 -07:00; committed by TensorFlower Gardener.
parent 36ac0a1ac5
commit 7ad65ad902
3 changed files with 106 additions and 22 deletions

@@ -190,6 +190,13 @@ bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
 }
 }
+
+bool NodeIsOnCpu(const NodeDef& node) {
+  string task;
+  string device;
+  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+         absl::StrContains(device, DEVICE_CPU);
+}

 // Graph optimizer context extension specific to ArithmeticOptimizer.
 struct ArithmeticOptimizerContext {
   explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
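The helper above is hoisted to namespace scope so that both ReplaceMulWithSquare and ConvertPowStage can share it (UnaryOpsComposition's private copy is removed further down). As a rough illustration of what the check does, assuming the usual fully qualified device-name layout — SplitDeviceNameSketch below is a hypothetical stand-in, not the real DeviceNameUtils API:

#include <iostream>
#include <string>

// Hypothetical simplification: the real DeviceNameUtils::SplitDeviceName
// also accepts partially specified and legacy name forms.
bool SplitDeviceNameSketch(const std::string& full_name, std::string* task,
                           std::string* device) {
  const std::string kMarker = "/device:";
  const std::size_t pos = full_name.rfind(kMarker);
  if (pos == std::string::npos) return false;
  *task = full_name.substr(0, pos);                  // "/job:w/replica:0/task:0"
  *device = full_name.substr(pos + kMarker.size());  // "CPU:0"
  return true;
}

int main() {
  std::string task, device;
  if (SplitDeviceNameSketch("/job:w/replica:0/task:0/device:CPU:0", &task,
                            &device)) {
    // device is "CPU:0", which contains "CPU", so the node counts as on-CPU.
    std::cout << "task=" << task << "\ndevice=" << device << "\n";
  }
  return 0;
}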
@@ -2361,13 +2368,7 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
     const DataType type = GetDataTypeFromAttr(*node, "T");
     bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);

-    string task;
-    string device;
-    bool is_on_cpu =
-        DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
-        absl::StrContains(device, DEVICE_CPU);
-    if (!is_complex || is_on_cpu) {
+    if (!is_complex || NodeIsOnCpu(*node)) {
       NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
       new_square_node->set_op("Square");
       for (int i = 1; i < new_square_node->input_size(); ++i) {
@@ -2528,6 +2529,30 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
       node->set_input(1, AsControlDependency(y->name()));
       AddToOptimizationQueue(node);
       AddToOptimizationQueue(y);
+    } else if (curr == complex128(3, 0)) {
+      // TODO(courbet): Use 'Cube' when it's added to TF ops.
+      if (NodeIsOnCpu(*node)) {
+        // We create an inner square node: inner_square = square(x)
+        const NodeScopeAndName scope_and_name =
+            ParseNodeScopeAndName(node->name());
+        const string inner_square_name =
+            OptimizedNodeName(scope_and_name, "_inner");
+        NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
+        if (inner_square_node == nullptr) {
+          inner_square_node = AddCopyNode(inner_square_name, node);
+          inner_square_node->set_op("Square");
+          inner_square_node->mutable_input()->RemoveLast();
+        }
+        ctx().node_map->AddOutput(x->name(), inner_square_node->name());
+        // We modify `node`: node = mul(x, inner_square);
+        node->set_op("Mul");
+        node->set_input(1, inner_square_node->name());
+        node->add_input(AsControlDependency(y->name()));
+
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(inner_square_node);
+        AddToOptimizationQueue(y);
+      }
     } else if (curr == complex128(1, 0) &&
                ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
       // Pow could be used to broadcast, so make sure the shapes of the two
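Net effect of the new branch: for a CPU-placed node out = Pow(x, y) whose exponent constant-folds to 3, the subgraph becomes (sketch, names per the test expectations in the next file, not tool output):

    Before:  out = Pow(x, y)
    After:   inner = Square(x)      // e.g. ArithmeticOptimizer/ConvertPow__inner_out3
             out   = Mul(x, inner)  // y survives only as control input ^y

Keeping ^y as a control dependency preserves any execution-order constraints the old data edge implied, while still letting y be pruned once nothing consumes its value.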
@@ -2985,17 +3010,6 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
            DrivesControlDependency(node));
   }

-  // UnaryOpsComposition is defined only for CPU.
-  bool NodeIsOnCpu(const NodeDef& node) const {
-    using absl::StartsWith;
-    string task;
-    string device;
-    return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
-           StartsWith(device, DEVICE_CPU);
-  }
-
   bool NodeIsAlreadyFused(const NodeDef& node) const {
     return fused_nodes_.count(node.name()) > 0;
   }

@@ -2728,6 +2728,7 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
   auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
+  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
   auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
   auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
   auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
@@ -2738,6 +2739,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
   auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
   Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
+  Output out3 =
+      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
   Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
   Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
   Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
@@ -2748,18 +2751,18 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

   GrapplerItem item;
-  item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
-                "out_1", "out", "out_bcast1", "out_bcast2"};
+  item.fetch = {"out2", "out3", "out1", "out.5", "out0",
+                "out_.5", "out_1", "out", "out_bcast1", "out_bcast2"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
-  ASSERT_EQ(tensors_expected.size(), 9);
+  ASSERT_EQ(tensors_expected.size(), 10);

   GraphDef got;
   ArithmeticOptimizer optimizer;
   EnableOnlyConvertPow(&optimizer);
   OptimizeAndPrune(&optimizer, &item, &got);
   auto tensors = EvaluateNodes(got, item.fetch);
-  ASSERT_EQ(tensors.size(), 9);
+  ASSERT_EQ(tensors.size(), 10);

   for (int i = 0; i < tensors.size(); ++i) {
     EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
@@ -2773,6 +2776,12 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   AddNode("ones", "Const", {}, {}, &want);
   AddNode("zeros", "Const", {}, {}, &want);
   AddNode("out2", "Square", {"x"}, {}, &want);
+  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
+          &want)
+      ->set_device("/device:CPU:0");
+  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
+          {}, &want)
+      ->set_device("/device:CPU:0");
   AddNode("out1", "Identity", {"x"}, {}, &want);
   AddNode("out.5", "Sqrt", {"x"}, {}, &want);
   AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);

@@ -147,6 +147,67 @@ BM_BINARY_SCALAR(sycl, DivNoNan);
 #undef BM_BINARY_SCALAR

+// Three implementations of x^3.
+Graph* CubeWithPow3(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  Tensor rhs(DT_FLOAT, TensorShape({}));
+  rhs.flat<float>().setConstant(3);
+  test::graph::Binary(g, "Pow", test::graph::Constant(g, lhs),
+                      test::graph::Constant(g, rhs));
+  return g;
+}
+
+Graph* CubeWithTwoMuls(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  auto* x = test::graph::Constant(g, lhs);
+  auto* inner = test::graph::Binary(g, "Mul", x, x);
+  test::graph::Binary(g, "Mul", x, inner);
+  return g;
+}
+
+Graph* CubeWithMulSquare(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  auto* x = test::graph::Constant(g, lhs);
+  auto* inner = test::graph::Unary(g, "Square", x);
+  test::graph::Binary(g, "Mul", test::graph::Constant(g, lhs), inner);
+  return g;
+}
+
+#define BM_CUBE(DEVICE, Impl)                          \
+  void BM_##DEVICE##_Cube_##Impl(int iters, int num) { \
+    const int64 tot = static_cast<int64>(iters) * num; \
+    testing::UseRealTime();                            \
+    testing::ItemsProcessed(tot);                      \
+    testing::BytesProcessed(tot * sizeof(float));      \
+    test::Benchmark(#DEVICE, Impl(num)).Run(iters);    \
+  }                                                    \
+  BENCHMARK(BM_##DEVICE##_Cube_##Impl)                 \
+      ->Arg(1 << 12) /* must >= 4096 */                \
+      ->Arg(1 << 16)                                   \
+      ->Arg(1 << 20);
+
+BM_CUBE(cpu, CubeWithPow3);
+BM_CUBE(cpu, CubeWithTwoMuls);
+BM_CUBE(cpu, CubeWithMulSquare);
+#if GOOGLE_CUDA
+BM_CUBE(gpu, CubeWithPow3);
+BM_CUBE(gpu, CubeWithTwoMuls);
+BM_CUBE(gpu, CubeWithMulSquare);
+#endif  // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_CUBE(sycl, CubeWithPow3);
+BM_CUBE(sycl, CubeWithTwoMuls);
+BM_CUBE(sycl, CubeWithMulSquare);
+#endif  // TENSORFLOW_USE_SYCL
+#undef BM_CUBE
+
 template <class T>
 Graph* BiasAdd(int rows, int cols, DataType type) {
   Graph* g = new Graph(OpRegistry::Global());
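For readers unfamiliar with the token pasting above, BM_CUBE(cpu, CubeWithPow3) expands to roughly the following (written out by hand; whitespace added). The Arg values are what produce the /4k, /64k, and /1M rows in the table at the top: 1 << 12 elements is 4k, and the input shape {64, 64, num / (64 * 64)} is why num must be at least 4096.

void BM_cpu_Cube_CubeWithPow3(int iters, int num) {
  const int64 tot = static_cast<int64>(iters) * num;  // elements touched overall
  testing::UseRealTime();
  testing::ItemsProcessed(tot);
  testing::BytesProcessed(tot * sizeof(float));  // 4 bytes per float element
  test::Benchmark("cpu", CubeWithPow3(num)).Run(iters);
}
BENCHMARK(BM_cpu_Cube_CubeWithPow3)
    ->Arg(1 << 12) /* must >= 4096 */
    ->Arg(1 << 16)
    ->Arg(1 << 20);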