From da3c1f4b01bc417d50adc2d0979b603baf859e44 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Apr 2020 16:15:29 +0000 Subject: [PATCH 1/6] Add XLA kernel for tf.ensure_shape This PR is part of the effort to fix 34363 and a prerequisite for PR 34399 which will resolve 34363. In order for the PR 34399 to pass all tests, an XLA kernel for tf.ensure_shape need to be added. Once this PR is merged, PR 34399 will be re-opened to address 34363. Signed-off-by: Yong Tang --- tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../tf2xla/kernels/ensure_shape_op.cc | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index dbb420b14fd..410edcd634d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -38,6 +38,7 @@ tf_kernel_library( "einsum_op.cc", "elu_op.cc", "elu_op.h", + "ensure_shape_op.cc", "empty_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc new file mode 100644 index 00000000000..62f53ab83c0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ensure_shape Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +class EnsureShapeOp : public XlaOpKernel { + public: + explicit EnsureShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape shape = ctx->InputShape(0); + + // valiate shape + OP_REQUIRES(ctx, expected_shape_.IsCompatibleWith(shape), + errors::InvalidArgument("Shape of tensor ", this->def().input(0), " ", shape.DebugString(), " is not compatible with expected shape ", expected_shape_.DebugString(), ".")); + + // If shape matches, outputs the tensor. + ctx->SetOutput(0, ctx->Input(0)); + } + private: + PartialTensorShape expected_shape_; +}; + +REGISTER_XLA_OP(Name("EnsureShape"), EnsureShapeOp); + +} // namespace +} // namespace tensorflow From 2ef79b9c919677d0598d0ca05abb50656ed50771 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Apr 2020 16:35:30 +0000 Subject: [PATCH 2/6] Sanitize with clang-format Signed-off-by: Yong Tang --- tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc index 62f53ab83c0..8221327d36f 100644 --- a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc @@ -38,12 +38,17 @@ class EnsureShapeOp : public XlaOpKernel { const TensorShape shape = ctx->InputShape(0); // valiate shape - OP_REQUIRES(ctx, expected_shape_.IsCompatibleWith(shape), - errors::InvalidArgument("Shape of tensor ", this->def().input(0), " ", shape.DebugString(), " is not compatible with expected shape ", expected_shape_.DebugString(), ".")); + OP_REQUIRES( + ctx, expected_shape_.IsCompatibleWith(shape), + errors::InvalidArgument("Shape of tensor ", this->def().input(0), " ", + shape.DebugString(), + " is not compatible with expected shape ", + expected_shape_.DebugString(), ".")); // If shape matches, outputs the tensor. ctx->SetOutput(0, ctx->Input(0)); } + private: PartialTensorShape expected_shape_; }; From f59cf3d105c81221db7e193712a60a4c7bda10b8 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Apr 2020 16:21:51 +0000 Subject: [PATCH 3/6] Add test case for XLA kernel of EnsureShape Signed-off-by: Yong Tang --- tensorflow/compiler/tests/BUILD | 17 +++++++ .../compiler/tests/ensure_shape_op_test.py | 50 +++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 tensorflow/compiler/tests/ensure_shape_op_test.py diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b6f5e11b856..00434974ccc 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1865,3 +1865,20 @@ tf_xla_py_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_xla_py_test( + name = "ensure_shape_op_test", + size = "medium", + srcs = ["ensure_shape_op_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) diff --git a/tensorflow/compiler/tests/ensure_shape_op_test.py b/tensorflow/compiler/tests/ensure_shape_op_test.py new file mode 100644 index 00000000000..397fd399fe0 --- /dev/null +++ b/tensorflow/compiler/tests/ensure_shape_op_test.py @@ -0,0 +1,50 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ensure_shape_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class EnsureShapeOpTest(xla_test.XLATestCase): + + def testEnsureShape(self): + with self.session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = check_ops.ensure_shape(p, (None, 3)) + expected_out = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + self.assertAllEqual(expected_out, sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})) + + def testInvalidEnsureShape(self): + with self.session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = check_ops.ensure_shape(p, (None, 3, 3)) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is not compatible with expected shape"): + sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]}) + + +if __name__ == "__main__": + test.main() From c4648ef718d718ba94f334628bd9cc3bad824e59 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Apr 2020 16:42:21 +0000 Subject: [PATCH 4/6] Pylint and bazel buildifier fix to pass Ubuntu Sanity CI test Signed-off-by: Yong Tang --- tensorflow/compiler/tests/ensure_shape_op_test.py | 6 ++++-- tensorflow/compiler/tf2xla/kernels/BUILD | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tests/ensure_shape_op_test.py b/tensorflow/compiler/tests/ensure_shape_op_test.py index 397fd399fe0..1729ffbae41 100644 --- a/tensorflow/compiler/tests/ensure_shape_op_test.py +++ b/tensorflow/compiler/tests/ensure_shape_op_test.py @@ -35,14 +35,16 @@ class EnsureShapeOpTest(xla_test.XLATestCase): with self.test_scope(): op = check_ops.ensure_shape(p, (None, 3)) expected_out = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] - self.assertAllEqual(expected_out, sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})) + self.assertAllEqual( + expected_out, sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})) def testInvalidEnsureShape(self): with self.session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = check_ops.ensure_shape(p, (None, 3, 3)) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is not compatible with expected shape"): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "is not compatible with expected shape"): sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]}) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 410edcd634d..4780bd7455e 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -38,8 +38,8 @@ tf_kernel_library( "einsum_op.cc", "elu_op.cc", "elu_op.h", - "ensure_shape_op.cc", "empty_op.cc", + "ensure_shape_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", "fake_quantize_ops.cc", From 88692052ad3e056a36f3da6d359ef72ffc783beb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Apr 2020 18:15:43 +0000 Subject: [PATCH 5/6] Remove unused math_ops import to pass CI test Signed-off-by: Yong Tang --- tensorflow/compiler/tests/ensure_shape_op_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/tests/ensure_shape_op_test.py b/tensorflow/compiler/tests/ensure_shape_op_test.py index 1729ffbae41..51aa82686fb 100644 --- a/tensorflow/compiler/tests/ensure_shape_op_test.py +++ b/tensorflow/compiler/tests/ensure_shape_op_test.py @@ -23,7 +23,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test From 1e46a8d079ef63bdebf54f77cae1888f871eb673 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Apr 2020 18:13:11 +0000 Subject: [PATCH 6/6] Add EnsureShape to XLALiteWhitelist to pass the CI test Signed-off-by: Yong Tang --- tensorflow/compiler/jit/mark_for_compilation_pass.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 77496fe7960..cb74d135418 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1891,6 +1891,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "DynamicStitch", "Einsum", "EmptyTensorList", + "EnsureShape", "ExtractImagePatches", "Igamma", "IgammaGradA",