Merge pull request #38544 from yongtang:34399-ensure_shape

PiperOrigin-RevId: 306652585
Change-Id: I53312a7f02ee07893816b1f358667e324aa879a6
This commit is contained in:
TensorFlower Gardener 2020-04-15 08:57:07 -07:00
commit c16a5caf79
5 changed files with 129 additions and 0 deletions

View File

@ -1891,6 +1891,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"DynamicStitch",
"Einsum",
"EmptyTensorList",
"EnsureShape",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",

View File

@ -1846,3 +1846,20 @@ tf_xla_py_test(
"@absl_py//absl/testing:parameterized",
],
)
tf_xla_py_test(
name = "ensure_shape_op_test",
size = "medium",
srcs = ["ensure_shape_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)

View File

@ -0,0 +1,51 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ensure_shape_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.platform import test
class EnsureShapeOpTest(xla_test.XLATestCase):
def testEnsureShape(self):
with self.session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = check_ops.ensure_shape(p, (None, 3))
expected_out = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
self.assertAllEqual(expected_out,
sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]}))
def testInvalidEnsureShape(self):
with self.session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = check_ops.ensure_shape(p, (None, 3, 3))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is not compatible with expected shape"):
sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})
if __name__ == "__main__":
test.main()

View File

@ -39,6 +39,7 @@ tf_kernel_library(
"elu_op.cc",
"elu_op.h",
"empty_op.cc",
"ensure_shape_op.cc",
"extract_image_patches_op.cc",
"fake_param_op.cc",
"fake_quantize_ops.cc",

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// XLA-specific ensure_shape Op.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
namespace {
class EnsureShapeOp : public XlaOpKernel {
public:
explicit EnsureShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape shape = ctx->InputShape(0);
// valiate shape
OP_REQUIRES(
ctx, expected_shape_.IsCompatibleWith(shape),
errors::InvalidArgument("Shape of tensor ", this->def().input(0), " ",
shape.DebugString(),
" is not compatible with expected shape ",
expected_shape_.DebugString(), "."));
// If shape matches, outputs the tensor.
ctx->SetOutput(0, ctx->Input(0));
}
private:
PartialTensorShape expected_shape_;
};
REGISTER_XLA_OP(Name("EnsureShape"), EnsureShapeOp);
} // namespace
} // namespace tensorflow