Add a grappler optimization (CPU) for pow(x, 3)
-> mul(x, square(x))
.
Benchmarks show mul+square to be about 13x to 50x faster than pow on CPUs: name time/op BM_cpu_Cube_CubeWithPow3/4k 362?s ? 0% BM_cpu_Cube_CubeWithPow3/64k 782?s ? 0% BM_cpu_Cube_CubeWithPow3/1M 8.59ms ? 0% BM_cpu_Cube_CubeWithTwoMuls/4k 11.4?s ? 3% BM_cpu_Cube_CubeWithTwoMuls/64k 61.5?s ?12% BM_cpu_Cube_CubeWithTwoMuls/1M 172?s ? 1% BM_cpu_Cube_CubeWithMulSquare/4k 13.7?s ? 3% BM_cpu_Cube_CubeWithMulSquare/64k 57.5?s ? 2% BM_cpu_Cube_CubeWithMulSquare/1M 173?s ? 1% PiperOrigin-RevId: 257149180
This commit is contained in:
parent
36ac0a1ac5
commit
7ad65ad902
@ -190,6 +190,13 @@ bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
|
||||
}
|
||||
}
|
||||
|
||||
bool NodeIsOnCpu(const NodeDef& node) {
|
||||
string task;
|
||||
string device;
|
||||
return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
|
||||
absl::StrContains(device, DEVICE_CPU);
|
||||
}
|
||||
|
||||
// Graph optimizer context extension specific to ArithmeticOptimizer.
|
||||
struct ArithmeticOptimizerContext {
|
||||
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
|
||||
@ -2361,13 +2368,7 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
|
||||
const DataType type = GetDataTypeFromAttr(*node, "T");
|
||||
bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
|
||||
|
||||
string task;
|
||||
string device;
|
||||
bool is_on_cpu =
|
||||
DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
|
||||
absl::StrContains(device, DEVICE_CPU);
|
||||
|
||||
if (!is_complex || is_on_cpu) {
|
||||
if (!is_complex || NodeIsOnCpu(*node)) {
|
||||
NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
|
||||
new_square_node->set_op("Square");
|
||||
for (int i = 1; i < new_square_node->input_size(); ++i) {
|
||||
@ -2528,6 +2529,30 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
node->set_input(1, AsControlDependency(y->name()));
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(y);
|
||||
} else if (curr == complex128(3, 0)) {
|
||||
// TODO(courbet): Use 'Cube' when it's added to TF ops.
|
||||
if (NodeIsOnCpu(*node)) {
|
||||
// We create an inner square node: inner_square = square(x)
|
||||
const NodeScopeAndName scope_and_name =
|
||||
ParseNodeScopeAndName(node->name());
|
||||
const string inner_square_name =
|
||||
OptimizedNodeName(scope_and_name, "_inner");
|
||||
NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
|
||||
if (inner_square_node == nullptr) {
|
||||
inner_square_node = AddCopyNode(inner_square_name, node);
|
||||
inner_square_node->set_op("Square");
|
||||
inner_square_node->mutable_input()->RemoveLast();
|
||||
}
|
||||
ctx().node_map->AddOutput(x->name(), inner_square_node->name());
|
||||
// We modify `node`: node = mul(x, inner_square);
|
||||
node->set_op("Mul");
|
||||
node->set_input(1, inner_square_node->name());
|
||||
node->add_input(AsControlDependency(y->name()));
|
||||
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(inner_square_node);
|
||||
AddToOptimizationQueue(y);
|
||||
}
|
||||
} else if (curr == complex128(1, 0) &&
|
||||
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
|
||||
// Pow could be used to broadcast, so make sure the shapes of the two
|
||||
@ -2985,17 +3010,6 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
|
||||
DrivesControlDependency(node));
|
||||
}
|
||||
|
||||
// UnaryOpsComposition is defined only for CPU.
|
||||
bool NodeIsOnCpu(const NodeDef& node) const {
|
||||
using absl::StartsWith;
|
||||
|
||||
string task;
|
||||
string device;
|
||||
|
||||
return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
|
||||
StartsWith(device, DEVICE_CPU);
|
||||
}
|
||||
|
||||
bool NodeIsAlreadyFused(const NodeDef& node) const {
|
||||
return fused_nodes_.count(node.name()) > 0;
|
||||
}
|
||||
|
@ -2728,6 +2728,7 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
|
||||
auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
|
||||
auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
|
||||
auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
|
||||
auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
|
||||
auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
|
||||
@ -2738,6 +2739,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
|
||||
auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
|
||||
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
|
||||
Output out3 =
|
||||
ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
|
||||
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
|
||||
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
|
||||
Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
|
||||
@ -2748,18 +2751,18 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
|
||||
"out_1", "out", "out_bcast1", "out_bcast2"};
|
||||
item.fetch = {"out2", "out3", "out1", "out.5", "out0",
|
||||
"out_.5", "out_1", "out", "out_bcast1", "out_bcast2"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
ASSERT_EQ(tensors_expected.size(), 9);
|
||||
ASSERT_EQ(tensors_expected.size(), 10);
|
||||
|
||||
GraphDef got;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyConvertPow(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &got);
|
||||
auto tensors = EvaluateNodes(got, item.fetch);
|
||||
ASSERT_EQ(tensors.size(), 9);
|
||||
ASSERT_EQ(tensors.size(), 10);
|
||||
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
|
||||
@ -2773,6 +2776,12 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
AddNode("ones", "Const", {}, {}, &want);
|
||||
AddNode("zeros", "Const", {}, {}, &want);
|
||||
AddNode("out2", "Square", {"x"}, {}, &want);
|
||||
AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
|
||||
&want)
|
||||
->set_device("/device:CPU:0");
|
||||
AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
|
||||
{}, &want)
|
||||
->set_device("/device:CPU:0");
|
||||
AddNode("out1", "Identity", {"x"}, {}, &want);
|
||||
AddNode("out.5", "Sqrt", {"x"}, {}, &want);
|
||||
AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
|
||||
|
@ -147,6 +147,67 @@ BM_BINARY_SCALAR(sycl, DivNoNan);
|
||||
|
||||
#undef BM_BINARY_SCALAR
|
||||
|
||||
// Three implementations of x^3.
|
||||
Graph* CubeWithPow3(int num) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
|
||||
lhs.flat<float>().setRandom();
|
||||
Tensor rhs(DT_FLOAT, TensorShape({}));
|
||||
rhs.flat<float>().setConstant(3);
|
||||
test::graph::Binary(g, "Pow", test::graph::Constant(g, lhs),
|
||||
test::graph::Constant(g, rhs));
|
||||
return g;
|
||||
}
|
||||
|
||||
Graph* CubeWithTwoMuls(int num) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
|
||||
lhs.flat<float>().setRandom();
|
||||
auto* x = test::graph::Constant(g, lhs);
|
||||
auto* inner = test::graph::Binary(g, "Mul", x, x);
|
||||
test::graph::Binary(g, "Mul", x, inner);
|
||||
return g;
|
||||
}
|
||||
|
||||
Graph* CubeWithMulSquare(int num) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
|
||||
lhs.flat<float>().setRandom();
|
||||
auto* x = test::graph::Constant(g, lhs);
|
||||
auto* inner = test::graph::Unary(g, "Square", x);
|
||||
test::graph::Binary(g, "Mul", test::graph::Constant(g, lhs), inner);
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_CUBE(DEVICE, Impl) \
|
||||
void BM_##DEVICE##_Cube_##Impl(int iters, int num) { \
|
||||
const int64 tot = static_cast<int64>(iters) * num; \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(tot); \
|
||||
testing::BytesProcessed(tot * sizeof(float)); \
|
||||
test::Benchmark(#DEVICE, Impl(num)).Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_##DEVICE##_Cube_##Impl) \
|
||||
->Arg(1 << 12) /* must >= 4096 */ \
|
||||
->Arg(1 << 16) \
|
||||
->Arg(1 << 20);
|
||||
|
||||
BM_CUBE(cpu, CubeWithPow3);
|
||||
BM_CUBE(cpu, CubeWithTwoMuls);
|
||||
BM_CUBE(cpu, CubeWithMulSquare);
|
||||
#if GOOGLE_CUDA
|
||||
BM_CUBE(gpu, CubeWithPow3);
|
||||
BM_CUBE(gpu, CubeWithTwoMuls);
|
||||
BM_CUBE(gpu, CubeWithMulSquare);
|
||||
#endif // GOOGLE_CUDA
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
BM_CUBE(sycl, CubeWithPow3);
|
||||
BM_CUBE(sycl, CubeWithTwoMuls);
|
||||
BM_CUBE(sycl, CubeWithMulSquare);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#undef BM_CUBE
|
||||
|
||||
template <class T>
|
||||
Graph* BiasAdd(int rows, int cols, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
|
Loading…
Reference in New Issue
Block a user