diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 77496fe7960..cb74d135418 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1891,6 +1891,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "DynamicStitch", "Einsum", "EmptyTensorList", + "EnsureShape", "ExtractImagePatches", "Igamma", "IgammaGradA", diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 5325addc8df..f830e4c1967 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1846,3 +1846,20 @@ tf_xla_py_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_xla_py_test( + name = "ensure_shape_op_test", + size = "medium", + srcs = ["ensure_shape_op_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) diff --git a/tensorflow/compiler/tests/ensure_shape_op_test.py b/tensorflow/compiler/tests/ensure_shape_op_test.py new file mode 100644 index 00000000000..95de5a9c49b --- /dev/null +++ b/tensorflow/compiler/tests/ensure_shape_op_test.py @@ -0,0 +1,51 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ensure_shape_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +class EnsureShapeOpTest(xla_test.XLATestCase): + + def testEnsureShape(self): + with self.session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = check_ops.ensure_shape(p, (None, 3)) + expected_out = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + self.assertAllEqual(expected_out, + sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})) + + def testInvalidEnsureShape(self): + with self.session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = check_ops.ensure_shape(p, (None, 3, 3)) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "is not compatible with expected shape"): + sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index dbb420b14fd..4780bd7455e 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -39,6 +39,7 @@ tf_kernel_library( "elu_op.cc", "elu_op.h", "empty_op.cc", + "ensure_shape_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", "fake_quantize_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc new file mode 100644 index 00000000000..8221327d36f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ensure_shape Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +class EnsureShapeOp : public XlaOpKernel { + public: + explicit EnsureShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape shape = ctx->InputShape(0); + + // valiate shape + OP_REQUIRES( + ctx, expected_shape_.IsCompatibleWith(shape), + errors::InvalidArgument("Shape of tensor ", this->def().input(0), " ", + shape.DebugString(), + " is not compatible with expected shape ", + expected_shape_.DebugString(), ".")); + + // If shape matches, outputs the tensor. + ctx->SetOutput(0, ctx->Input(0)); + } + + private: + PartialTensorShape expected_shape_; +}; + +REGISTER_XLA_OP(Name("EnsureShape"), EnsureShapeOp); + +} // namespace +} // namespace tensorflow