Add a grappler optimization (CPU) for pow(x, 3) -> mul(x, square(x))

Benchmarks show mul+square to be about 13x to 50x faster than pow on CPUs:

name                                time/op
BM_cpu_Cube_CubeWithPow3/4k         362µs  ± 0%
BM_cpu_Cube_CubeWithPow3/64k        782µs  ± 0%
BM_cpu_Cube_CubeWithPow3/1M         8.59ms ± 0%
BM_cpu_Cube_CubeWithTwoMuls/4k      11.4µs ± 3%
BM_cpu_Cube_CubeWithTwoMuls/64k     61.5µs ±12%
BM_cpu_Cube_CubeWithTwoMuls/1M      172µs  ± 1%
BM_cpu_Cube_CubeWithMulSquare/4k    13.7µs ± 3%
BM_cpu_Cube_CubeWithMulSquare/64k   57.5µs ± 2%
BM_cpu_Cube_CubeWithMulSquare/1M    173µs  ± 1%

PiperOrigin-RevId: 257149180
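The gap is not TensorFlow-specific: a generic pow typically goes through a transcendental routine (roughly exp(3·log x) per element), while the rewritten form is two multiplies that vectorize. The following standalone C++ sketch (illustrative only, not part of this change; absolute timings will vary by compiler and CPU) reproduces the effect:

#include <chrono>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const size_t n = 1 << 20;  // matches the /1M benchmark size
  std::vector<float> x(n, 1.0001f);
  auto bench = [&](const char* name, auto cube) {
    float sum = 0.f;  // consumed below so the loop is not optimized away
    auto t0 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < n; ++i) sum += cube(x[i]);
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - t0)
                  .count();
    std::printf("%-10s %8lld us (sum=%g)\n", name,
                static_cast<long long>(us), sum);
  };
  bench("pow", [](float v) { return std::pow(v, 3.0f); });
  bench("two muls", [](float v) { return v * v * v; });
  bench("mul+square", [](float v) { float s = v * v; return v * s; });
}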
This commit: 7ad65ad902 (parent 36ac0a1ac5)
arithmetic_optimizer.cc — a shared helper for detecting CPU placement:

@@ -190,6 +190,13 @@ bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
   }
 }
 
+bool NodeIsOnCpu(const NodeDef& node) {
+  string task;
+  string device;
+  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+         absl::StrContains(device, DEVICE_CPU);
+}
+
 // Graph optimizer context extension specific to ArithmeticOptimizer.
 struct ArithmeticOptimizerContext {
   explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
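For readers unfamiliar with TF device strings, a fully qualified name looks like /job:worker/replica:0/task:0/device:CPU:0. The sketch below is a simplified standalone model of the helper above (the real DeviceNameUtils::SplitDeviceName handles more forms; this is an approximation for intuition). Note that a node whose device string fails to parse does not match, so the rewrites guarded by this helper only fire for nodes explicitly placed on CPU — which is why the test further down pins its node to /device:CPU:0.

#include <iostream>
#include <string>

// Simplified stand-in for NodeIsOnCpu (not the TF implementation).
bool NodeIsOnCpuModel(const std::string& device_name) {
  const size_t slash = device_name.rfind('/');
  if (slash == std::string::npos) return false;  // empty/unparsable: not CPU
  std::string device = device_name.substr(slash + 1);  // "<TYPE>:<id>" part
  const std::string kPrefix = "device:";
  if (device.compare(0, kPrefix.size(), kPrefix) == 0)
    device = device.substr(kPrefix.size());
  return device.find("CPU") != std::string::npos;
}

int main() {
  std::cout << NodeIsOnCpuModel("/job:w/replica:0/task:0/device:CPU:0")  // 1
            << NodeIsOnCpuModel("/device:GPU:0")                         // 0
            << NodeIsOnCpuModel("") << "\n";                             // 0
}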
arithmetic_optimizer.cc — ReplaceMulWithSquare switches to the shared helper:

@@ -2361,13 +2368,7 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
     const DataType type = GetDataTypeFromAttr(*node, "T");
     bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
 
-    string task;
-    string device;
-    bool is_on_cpu =
-        DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
-        absl::StrContains(device, DEVICE_CPU);
-
-    if (!is_complex || is_on_cpu) {
+    if (!is_complex || NodeIsOnCpu(*node)) {
       NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
       new_square_node->set_op("Square");
       for (int i = 1; i < new_square_node->input_size(); ++i) {
arithmetic_optimizer.cc — ConvertPowStage learns the pow(x, 3) rewrite:

@@ -2528,6 +2529,30 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
       node->set_input(1, AsControlDependency(y->name()));
       AddToOptimizationQueue(node);
       AddToOptimizationQueue(y);
+    } else if (curr == complex128(3, 0)) {
+      // TODO(courbet): Use 'Cube' when it's added to TF ops.
+      if (NodeIsOnCpu(*node)) {
+        // We create an inner square node: inner_square = square(x)
+        const NodeScopeAndName scope_and_name =
+            ParseNodeScopeAndName(node->name());
+        const string inner_square_name =
+            OptimizedNodeName(scope_and_name, "_inner");
+        NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
+        if (inner_square_node == nullptr) {
+          inner_square_node = AddCopyNode(inner_square_name, node);
+          inner_square_node->set_op("Square");
+          inner_square_node->mutable_input()->RemoveLast();
+        }
+        ctx().node_map->AddOutput(x->name(), inner_square_node->name());
+        // We modify `node`: node = mul(x, inner_square);
+        node->set_op("Mul");
+        node->set_input(1, inner_square_node->name());
+        node->add_input(AsControlDependency(y->name()));
+
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(inner_square_node);
+        AddToOptimizationQueue(y);
+      }
     } else if (curr == complex128(1, 0) &&
                ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
       // Pow could be used to broadcast, so make sure the shapes of the two
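Stripped of the optimizer plumbing, the rewrite has a simple shape: pow(x, 3) becomes mul(x, square(x)), with the inner Square node created once (and reused if a previous pass already made it) and the exponent y kept alive as a control dependency so execution ordering is preserved. A toy standalone model of that shape (hypothetical ToyNode/ToyGraph types, not TF's NodeDef/NodeMap API):

#include <map>
#include <string>
#include <vector>

struct ToyNode {
  std::string op;
  std::vector<std::string> inputs;  // "^name" marks a control dependency.
};
using ToyGraph = std::map<std::string, ToyNode>;

// Rewrites g[name] = Pow(x, y) into Mul(x, inner) with inner = Square(x),
// mirroring the ConvertPowStage logic above for exponent 3.
void ConvertPow3(ToyGraph& g, const std::string& name) {
  ToyNode& node = g.at(name);
  if (node.op != "Pow") return;
  const std::string x = node.inputs[0];
  const std::string y = node.inputs[1];
  const std::string inner = name + "_inner";  // cf. OptimizedNodeName(..., "_inner")
  if (g.count(inner) == 0) g[inner] = ToyNode{"Square", {x}};  // create once
  node.op = "Mul";
  node.inputs = {x, inner, "^" + y};  // keep y as a control input
}

int main() {
  ToyGraph g = {{"x", {"Const", {}}},
                {"y", {"Const", {}}},
                {"p", {"Pow", {"x", "y"}}}};
  ConvertPow3(g, "p");
  // Now: g["p"] == Mul(x, p_inner, ^y) and g["p_inner"] == Square(x).
}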
arithmetic_optimizer.cc — UnaryOpsComposition drops its private copy of the CPU check in favor of the shared helper:

@@ -2985,17 +3010,6 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
            DrivesControlDependency(node));
   }
 
-  // UnaryOpsComposition is defined only for CPU.
-  bool NodeIsOnCpu(const NodeDef& node) const {
-    using absl::StartsWith;
-
-    string task;
-    string device;
-
-    return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
-           StartsWith(device, DEVICE_CPU);
-  }
-
   bool NodeIsAlreadyFused(const NodeDef& node) const {
     return fused_nodes_.count(node.name()) > 0;
   }
arithmetic_optimizer_test.cc — the ConvertPow test gains a cube case pinned to CPU:

@@ -2728,6 +2728,7 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
   auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
+  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
   auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
   auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
   auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
@@ -2738,6 +2739,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
   auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
   Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
+  Output out3 =
+      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
   Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
   Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
   Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
@@ -2748,18 +2751,18 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
 
   GrapplerItem item;
-  item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
-                "out_1", "out", "out_bcast1", "out_bcast2"};
+  item.fetch = {"out2", "out3", "out1", "out.5", "out0",
+                "out_.5", "out_1", "out", "out_bcast1", "out_bcast2"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
-  ASSERT_EQ(tensors_expected.size(), 9);
+  ASSERT_EQ(tensors_expected.size(), 10);
 
   GraphDef got;
   ArithmeticOptimizer optimizer;
   EnableOnlyConvertPow(&optimizer);
   OptimizeAndPrune(&optimizer, &item, &got);
   auto tensors = EvaluateNodes(got, item.fetch);
-  ASSERT_EQ(tensors.size(), 9);
+  ASSERT_EQ(tensors.size(), 10);
 
   for (int i = 0; i < tensors.size(); ++i) {
     EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
@@ -2773,6 +2776,12 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   AddNode("ones", "Const", {}, {}, &want);
   AddNode("zeros", "Const", {}, {}, &want);
   AddNode("out2", "Square", {"x"}, {}, &want);
+  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
+          &want)
+      ->set_device("/device:CPU:0");
+  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
+          {}, &want)
+      ->set_device("/device:CPU:0");
   AddNode("out1", "Identity", {"x"}, {}, &want);
   AddNode("out.5", "Sqrt", {"x"}, {}, &want);
   AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
cwise_ops_test.cc — benchmarks comparing the three cube implementations:

@@ -147,6 +147,67 @@ BM_BINARY_SCALAR(sycl, DivNoNan);
 
 #undef BM_BINARY_SCALAR
 
+// Three implementations of x^3.
+Graph* CubeWithPow3(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  Tensor rhs(DT_FLOAT, TensorShape({}));
+  rhs.flat<float>().setConstant(3);
+  test::graph::Binary(g, "Pow", test::graph::Constant(g, lhs),
+                      test::graph::Constant(g, rhs));
+  return g;
+}
+
+Graph* CubeWithTwoMuls(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  auto* x = test::graph::Constant(g, lhs);
+  auto* inner = test::graph::Binary(g, "Mul", x, x);
+  test::graph::Binary(g, "Mul", x, inner);
+  return g;
+}
+
+Graph* CubeWithMulSquare(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  auto* x = test::graph::Constant(g, lhs);
+  auto* inner = test::graph::Unary(g, "Square", x);
+  test::graph::Binary(g, "Mul", test::graph::Constant(g, lhs), inner);
+  return g;
+}
+
+#define BM_CUBE(DEVICE, Impl)                           \
+  void BM_##DEVICE##_Cube_##Impl(int iters, int num) {  \
+    const int64 tot = static_cast<int64>(iters) * num;  \
+    testing::UseRealTime();                             \
+    testing::ItemsProcessed(tot);                       \
+    testing::BytesProcessed(tot * sizeof(float));       \
+    test::Benchmark(#DEVICE, Impl(num)).Run(iters);     \
+  }                                                     \
+  BENCHMARK(BM_##DEVICE##_Cube_##Impl)                  \
+      ->Arg(1 << 12) /* must >= 4096 */                 \
+      ->Arg(1 << 16)                                    \
+      ->Arg(1 << 20);
+
+BM_CUBE(cpu, CubeWithPow3);
+BM_CUBE(cpu, CubeWithTwoMuls);
+BM_CUBE(cpu, CubeWithMulSquare);
+#if GOOGLE_CUDA
+BM_CUBE(gpu, CubeWithPow3);
+BM_CUBE(gpu, CubeWithTwoMuls);
+BM_CUBE(gpu, CubeWithMulSquare);
+#endif  // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_CUBE(sycl, CubeWithPow3);
+BM_CUBE(sycl, CubeWithTwoMuls);
+BM_CUBE(sycl, CubeWithMulSquare);
+#endif  // TENSORFLOW_USE_SYCL
+
+#undef BM_CUBE
+
 template <class T>
 Graph* BiasAdd(int rows, int cols, DataType type) {
   Graph* g = new Graph(OpRegistry::Global());
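For reference, each BM_CUBE(DEVICE, Impl) line above stamps out one benchmark per (device, implementation) pair using the old-style TF harness, where the framework supplies iters. Expanded approximately, BM_CUBE(cpu, CubeWithPow3) becomes the following, and the Arg values are what produce the /4k, /64k, and /1M rows in the table at the top of this commit:

void BM_cpu_Cube_CubeWithPow3(int iters, int num) {
  const int64 tot = static_cast<int64>(iters) * num;
  testing::UseRealTime();                        // report wall time
  testing::ItemsProcessed(tot);                  // elements cubed
  testing::BytesProcessed(tot * sizeof(float));  // enables throughput stats
  test::Benchmark("cpu", CubeWithPow3(num)).Run(iters);
}
BENCHMARK(BM_cpu_Cube_CubeWithPow3)
    ->Arg(1 << 12)   // the "/4k" rows
    ->Arg(1 << 16)   // "/64k"
    ->Arg(1 << 20);  // "/1M"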