diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index cb6f77efd1a..273460050fc 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -190,6 +190,13 @@ bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
   }
 }
 
+bool NodeIsOnCpu(const NodeDef& node) {
+  string task;
+  string device;
+  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+         absl::StrContains(device, DEVICE_CPU);
+}
+
 // Graph optimizer context extension specific to ArithmeticOptimizer.
 struct ArithmeticOptimizerContext {
   explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2361,13 +2368,7 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
     const DataType type = GetDataTypeFromAttr(*node, "T");
     bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
 
-    string task;
-    string device;
-    bool is_on_cpu =
-        DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
-        absl::StrContains(device, DEVICE_CPU);
-
-    if (!is_complex || is_on_cpu) {
+    if (!is_complex || NodeIsOnCpu(*node)) {
       NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
       new_square_node->set_op("Square");
       for (int i = 1; i < new_square_node->input_size(); ++i) {
@@ -2528,6 +2529,30 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
       node->set_input(1, AsControlDependency(y->name()));
       AddToOptimizationQueue(node);
       AddToOptimizationQueue(y);
+    } else if (curr == complex128(3, 0)) {
+      // TODO(courbet): Use 'Cube' when it's added to TF ops.
+      if (NodeIsOnCpu(*node)) {
+        // We create an inner square node: inner_square = square(x)
+        const NodeScopeAndName scope_and_name =
+            ParseNodeScopeAndName(node->name());
+        const string inner_square_name =
+            OptimizedNodeName(scope_and_name, "_inner");
+        NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
+        if (inner_square_node == nullptr) {
+          inner_square_node = AddCopyNode(inner_square_name, node);
+          inner_square_node->set_op("Square");
+          inner_square_node->mutable_input()->RemoveLast();
+        }
+        ctx().node_map->AddOutput(x->name(), inner_square_node->name());
+        // We modify `node`: node = mul(x, inner_square);
+        node->set_op("Mul");
+        node->set_input(1, inner_square_node->name());
+        node->add_input(AsControlDependency(y->name()));
+
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(inner_square_node);
+        AddToOptimizationQueue(y);
+      }
     } else if (curr == complex128(1, 0) &&
                ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
       // Pow could be used to broadcast, so make sure the shapes of the two
@@ -2985,17 +3010,6 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
            DrivesControlDependency(node));
   }
 
-  // UnaryOpsComposition is defined only for CPU.
-  bool NodeIsOnCpu(const NodeDef& node) const {
-    using absl::StartsWith;
-
-    string task;
-    string device;
-
-    return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
-           StartsWith(device, DEVICE_CPU);
-  }
-
   bool NodeIsAlreadyFused(const NodeDef& node) const {
     return fused_nodes_.count(node.name()) > 0;
   }
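Note on the ConvertPowStage change above: on CPU, Pow(x, 3) is rewritten into
node = Mul(x, inner_square) with inner_square = Square(x), since TF has no
dedicated 'Cube' op yet (hence the TODO). The exponent input y is demoted to a
control dependency so that its producer is not silently pruned. A minimal
standalone sketch (plain C++, not part of the patch) checking that the
rewritten form agrees with pow(x, 3) up to rounding:

    // Standalone sanity check, independent of TensorFlow.
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      for (float x : {-2.5f, -1.0f, 0.0f, 0.5f, 3.0f}) {
        const float inner_square = x * x;     // inner_square = Square(x)
        const float cube = x * inner_square;  // node = Mul(x, inner_square)
        const float expected = std::pow(x, 3.0f);
        // pow() and two multiplies may round differently; allow small slack.
        assert(std::fabs(cube - expected) <=
               1e-5f * std::fabs(expected) + 1e-6f);
        std::printf("x = %g, x^3 = %g\n", x, cube);
      }
      return 0;
    }

Two multiplies avoid the transcendental pow kernel entirely, which is what the
new cwise_ops_test benchmarks below quantify.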
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d9ce9f66b7a..ae3da034212 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2728,6 +2728,7 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
   auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
+  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
   auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
   auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
   auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
@@ -2738,6 +2739,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
   auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
   Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
+  Output out3 =
+      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
   Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
   Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
   Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
@@ -2748,18 +2751,18 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
 
   GrapplerItem item;
-  item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
-                "out_1", "out", "out_bcast1", "out_bcast2"};
+  item.fetch = {"out2",   "out3",  "out1", "out.5",      "out0",
+                "out_.5", "out_1", "out",  "out_bcast1", "out_bcast2"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
-  ASSERT_EQ(tensors_expected.size(), 9);
+  ASSERT_EQ(tensors_expected.size(), 10);
 
   GraphDef got;
   ArithmeticOptimizer optimizer;
   EnableOnlyConvertPow(&optimizer);
   OptimizeAndPrune(&optimizer, &item, &got);
   auto tensors = EvaluateNodes(got, item.fetch);
-  ASSERT_EQ(tensors.size(), 9);
+  ASSERT_EQ(tensors.size(), 10);
 
   for (int i = 0; i < tensors.size(); ++i) {
     EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
@@ -2773,6 +2776,12 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   AddNode("ones", "Const", {}, {}, &want);
   AddNode("zeros", "Const", {}, {}, &want);
   AddNode("out2", "Square", {"x"}, {}, &want);
+  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
+          &want)
+      ->set_device("/device:CPU:0");
+  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
+          {}, &want)
+      ->set_device("/device:CPU:0");
   AddNode("out1", "Identity", {"x"}, {}, &want);
   AddNode("out.5", "Sqrt", {"x"}, {}, &want);
   AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
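The test pins out3 to /device:CPU:0 because the rewrite only fires for nodes
placed on CPU; the fetch list accordingly grows from 9 to 10 outputs. The
expected inner node name ArithmeticOptimizer/ConvertPow__inner_out3 comes from
OptimizedNodeName(scope_and_name, "_inner"). A rough illustration of that
naming (the composition below is an assumption inferred from the expected test
node, not the real grappler helper):

    // Hypothetical reconstruction of the optimized-node naming scheme.
    #include <cstdio>
    #include <string>

    std::string InnerSquareName(const std::string& optimizer,
                                const std::string& stage,
                                const std::string& node_name) {
      // "<optimizer>/<stage>_<suffix>_<node>" with suffix "_inner" yields
      // "ArithmeticOptimizer/ConvertPow__inner_out3" for node "out3".
      return optimizer + "/" + stage + "_" + "_inner" + "_" + node_name;
    }

    int main() {
      std::printf("%s\n", InnerSquareName("ArithmeticOptimizer", "ConvertPow",
                                          "out3").c_str());
      return 0;
    }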
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
index d6ce0f1cfa5..739ccf7730a 100644
--- a/tensorflow/core/kernels/cwise_ops_test.cc
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -147,6 +147,67 @@ BM_BINARY_SCALAR(sycl, DivNoNan);
 
 #undef BM_BINARY_SCALAR
 
+// Three implementations of x^3.
+Graph* CubeWithPow3(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  Tensor rhs(DT_FLOAT, TensorShape({}));
+  rhs.flat<float>().setConstant(3);
+  test::graph::Binary(g, "Pow", test::graph::Constant(g, lhs),
+                      test::graph::Constant(g, rhs));
+  return g;
+}
+
+Graph* CubeWithTwoMuls(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  auto* x = test::graph::Constant(g, lhs);
+  auto* inner = test::graph::Binary(g, "Mul", x, x);
+  test::graph::Binary(g, "Mul", x, inner);
+  return g;
+}
+
+Graph* CubeWithMulSquare(int num) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+  lhs.flat<float>().setRandom();
+  auto* x = test::graph::Constant(g, lhs);
+  auto* inner = test::graph::Unary(g, "Square", x);
+  test::graph::Binary(g, "Mul", test::graph::Constant(g, lhs), inner);
+  return g;
+}
+
+#define BM_CUBE(DEVICE, Impl)                          \
+  void BM_##DEVICE##_Cube_##Impl(int iters, int num) { \
+    const int64 tot = static_cast<int64>(iters) * num; \
+    testing::UseRealTime();                            \
+    testing::ItemsProcessed(tot);                      \
+    testing::BytesProcessed(tot * sizeof(float));      \
+    test::Benchmark(#DEVICE, Impl(num)).Run(iters);    \
+  }                                                    \
+  BENCHMARK(BM_##DEVICE##_Cube_##Impl)                 \
+      ->Arg(1 << 12) /* must be >= 4096 */             \
+      ->Arg(1 << 16)                                   \
+      ->Arg(1 << 20);
+
+BM_CUBE(cpu, CubeWithPow3);
+BM_CUBE(cpu, CubeWithTwoMuls);
+BM_CUBE(cpu, CubeWithMulSquare);
+#if GOOGLE_CUDA
+BM_CUBE(gpu, CubeWithPow3);
+BM_CUBE(gpu, CubeWithTwoMuls);
+BM_CUBE(gpu, CubeWithMulSquare);
+#endif  // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_CUBE(sycl, CubeWithPow3);
+BM_CUBE(sycl, CubeWithTwoMuls);
+BM_CUBE(sycl, CubeWithMulSquare);
+#endif  // TENSORFLOW_USE_SYCL
+
+#undef BM_CUBE
+
 template <typename T>
 Graph* BiasAdd(int rows, int cols, DataType type) {
   Graph* g = new Graph(OpRegistry::Global());
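The three graphs above benchmark x^3 as a single Pow, as two explicit Muls,
and as Mul plus Square (the form the optimizer now emits); num must be at
least 1 << 12 because the benchmark tensor shape is {64, 64, num / (64 * 64)}.
The in-tree benchmarks are normally run from the cwise_ops_test binary via the
--benchmarks flag (e.g. --benchmarks=BM_cpu_Cube, assuming the usual
tensorflow::testing benchmark runner). For a quick out-of-tree comparison,
here is a minimal wall-clock sketch (plain C++, no TF dependency; the in-tree
numbers are authoritative):

    // Times the three scalar formulations over a large float buffer.
    #include <chrono>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    template <typename F>
    double Seconds(F f, const std::vector<float>& in) {
      const auto start = std::chrono::steady_clock::now();
      float checksum = 0.0f;
      for (int iter = 0; iter < 100; ++iter) {
        for (size_t i = 0; i < in.size(); ++i) checksum += f(in[i]);
      }
      const double secs = std::chrono::duration<double>(
                              std::chrono::steady_clock::now() - start)
                              .count();
      volatile float sink = checksum;  // keep the loop from being elided
      (void)sink;
      return secs;
    }

    int main() {
      std::vector<float> in(1 << 20);
      for (size_t i = 0; i < in.size(); ++i) in[i] = 0.5f * (i % 7);
      const double t_pow =
          Seconds([](float x) { return std::pow(x, 3.0f); }, in);
      const double t_two_muls =
          Seconds([](float x) { return x * (x * x); }, in);
      const double t_mul_square = Seconds(
          [](float x) { const float s = x * x; return x * s; }, in);
      std::printf("pow: %.3fs  two muls: %.3fs  mul+square: %.3fs\n", t_pow,
                  t_two_muls, t_mul_square);
      return 0;
    }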