# Source: STT-tensorflow/tensorflow/python/kernel_tests/substr_op_test.py
# Last commit: Gaurav Jain 24f578cd66
#   Add @run_deprecated_v1 annotation to tests failing in v2
#   PiperOrigin-RevId: 223422907
#   2018-11-29 15:43:25 -08:00
#
# 498 lines, 20 KiB, Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Substr op from string_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class SubstrOpTest(test.TestCase, parameterized.TestCase):
  """Tests for `string_ops.substr` using the "BYTE" and "UTF8_CHAR" units.

  Most tests are parameterized over the integer dtype used for `pos`/`len`
  (`np.int32` / `np.int64`) and over `unit`.  Per-unit inputs and expected
  outputs are kept in small dicts keyed by the unit name and selected with
  `[unit]`.
  """

  def _checkSubstr(self, test_string, position, length, unit, expected_value):
    """Builds a substr op, evaluates it and compares it to `expected_value`.

    Args:
      test_string: Input string(s) passed to `string_ops.substr`.
      position: Start position(s), as a numpy array.
      length: Substring length(s), as a numpy array.
      unit: Unit name forwarded as the `unit` keyword argument.
      expected_value: Value the evaluated op must equal.
    """
    substr_op = string_ops.substr(test_string, position, length, unit=unit)
    with self.cached_session():
      substr = self.evaluate(substr_op)
      self.assertAllEqual(substr, expected_value)

  def _checkSubstrOutOfRange(self, test_string, position, length, unit):
    """Asserts that evaluating substr raises `InvalidArgumentError`.

    The op is built outside the `assertRaises` context, so only an error
    raised while evaluating it satisfies the assertion.

    Args:
      test_string: Input string(s) passed to `string_ops.substr`.
      position: Start position(s), as a numpy array.
      length: Substring length(s), as a numpy array.
      unit: Unit name forwarded as the `unit` keyword argument.
    """
    substr_op = string_ops.substr(test_string, position, length, unit=unit)
    with self.cached_session():
      with self.assertRaises(errors_impl.InvalidArgumentError):
        self.evaluate(substr_op)

  @parameterized.parameters(
      (np.int32, 1, "BYTE"),
      (np.int64, 1, "BYTE"),
      (np.int32, -4, "BYTE"),
      (np.int64, -4, "BYTE"),
      (np.int32, 1, "UTF8_CHAR"),
      (np.int64, 1, "UTF8_CHAR"),
      (np.int32, -4, "UTF8_CHAR"),
      (np.int64, -4, "UTF8_CHAR"),
  )
  def testScalarString(self, dtype, pos, unit):
    """Takes three units from a scalar string at a positive or negative pos."""
    test_string = {
        "BYTE": b"Hello",
        "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"),
    }[unit]
    expected_value = {
        "BYTE": b"ell",
        "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"),
    }[unit]
    position = np.array(pos, dtype)
    length = np.array(3, dtype)
    self._checkSubstr(test_string, position, length, unit, expected_value)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  def testScalarString_EdgeCases(self, dtype, unit):
    """Covers empty input, whole-string slices and a length past the end."""
    # Empty string
    test_string = {
        "BYTE": b"",
        "UTF8_CHAR": u"".encode("utf-8"),
    }[unit]
    expected_value = b""
    position = np.array(0, dtype)
    length = np.array(3, dtype)
    self._checkSubstr(test_string, position, length, unit, expected_value)

    # The remaining cases all slice the same five-unit string.
    test_string = {
        "BYTE": b"Hello",
        "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
    }[unit]

    # Full string
    position = np.array(0, dtype)
    length = np.array(5, dtype)
    self._checkSubstr(test_string, position, length, unit, test_string)

    # Full string (Negative)
    position = np.array(-5, dtype)
    length = np.array(5, dtype)
    self._checkSubstr(test_string, position, length, unit, test_string)

    # Length is larger in magnitude than a negative position
    expected_string = {
        "BYTE": b"ello",
        "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"),
    }[unit]
    position = np.array(-4, dtype)
    length = np.array(5, dtype)
    self._checkSubstr(test_string, position, length, unit, expected_string)

  @parameterized.parameters(
      (np.int32, 1, "BYTE"),
      (np.int64, 1, "BYTE"),
      (np.int32, -4, "BYTE"),
      (np.int64, -4, "BYTE"),
      (np.int32, 1, "UTF8_CHAR"),
      (np.int64, 1, "UTF8_CHAR"),
      (np.int32, -4, "UTF8_CHAR"),
      (np.int64, -4, "UTF8_CHAR"),
  )
  def testVectorStrings(self, dtype, pos, unit):
    """Applies a scalar pos/len to every element of a vector of strings."""
    test_string = {
        "BYTE": [b"Hello", b"World"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo",
                                                  u"W\U0001f604rld"]],
    }[unit]
    expected_value = {
        "BYTE": [b"ell", b"orl"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]],
    }[unit]
    position = np.array(pos, dtype)
    length = np.array(3, dtype)
    self._checkSubstr(test_string, position, length, unit, expected_value)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  def testMatrixStrings(self, dtype, unit):
    """Applies a scalar pos/len to a matrix of strings (positive, negative)."""
    test_string = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"He\xc3\xc3o",
                                                   u"W\U0001f604rld",
                                                   u"d\xfcd\xea"]]],
    }[unit]

    # Positive scalar position.
    position = np.array(1, dtype)
    length = np.array(4, dtype)
    expected_value = {
        "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
                 [b"ixte", b"even", b"ight"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
                                                   u"\u053c\u025bv\u025b",
                                                   u"w\u0c1dlv"]],
                      [x.encode("utf-8") for x in [u"e\xc3\xc3o",
                                                   u"\U0001f604rld",
                                                   u"\xfcd\xea"]]],
    }[unit]
    self._checkSubstr(test_string, position, length, unit, expected_value)

    # Negative scalar position.
    position = np.array(-3, dtype)
    length = np.array(2, dtype)
    expected_value = {
        "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"],
                 [b"ee", b"ee", b"ee"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227",
                                                   u"v\u025b", u"lv"]],
                      [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl",
                                                   u"\xfcd"]]],
    }[unit]
    self._checkSubstr(test_string, position, length, unit, expected_value)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  def testElementWisePosLen(self, dtype, unit):
    """Uses pos/len matrices with the same 3x3 shape as the input strings."""
    test_string = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"He\xc3\xc3o",
                                                   u"W\U0001f604rld",
                                                   u"d\xfcd\xea"]],
                      [x.encode("utf-8") for x in [u"sixt\xea\xean",
                                                   u"se\U00010299enteen",
                                                   u"ei\U0001e920h\x86een"]]],
    }[unit]
    position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
    length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
    expected_value = {
        "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
                 [b"xteen", b"vente", b"hteen"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
                                                   u"\u025bv",
                                                   u"lv\u025b"]],
                      [x.encode("utf-8") for x in [u"e\xc3\xc3o",
                                                   u"rld",
                                                   u"d\xfc"]],
                      [x.encode("utf-8") for x in [u"xt\xea\xean",
                                                   u"\U00010299ente",
                                                   u"h\x86een"]]],
    }[unit]
    self._checkSubstr(test_string, position, length, unit, expected_value)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  def testBroadcast(self, dtype, unit):
    """Checks broadcasting between the input strings and pos/len."""
    # Broadcast pos/len onto input string
    test_string = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"],
                 [b"nineteen", b"twenty", b"twentyone"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
                                                   u"f\U0001f604urt\xea\xean",
                                                   u"f\xcd\ua09ctee\ua0e4"]],
                      [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
                                                   u"se\U00010299enteen",
                                                   u"ei\U0001e920h\x86een"]],
                      [x.encode("utf-8") for x in [u"nineteen",
                                                   u"twenty",
                                                   u"twentyone"]]],
    }[unit]
    position = np.array([1, -4, 3], dtype)
    length = np.array([1, 2, 3], dtype)
    expected_value = {
        "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
                 [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227",
                                                   u"\u025bv", u"lv\u025b"]],
                      [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]],
                      [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]],
                      [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]],
    }[unit]
    self._checkSubstr(test_string, position, length, unit, expected_value)

    # Broadcast input string onto pos/len
    test_string = {
        "BYTE": [b"thirteen", b"fourteen", b"fifteen"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
                                                  u"f\U0001f604urt\xea\xean",
                                                  u"f\xcd\ua09ctee\ua0e4"]],
    }[unit]
    position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
    length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
    expected_value = {
        "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
                 [b"ee", b"ee", b"ft"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]],
                      [x.encode("utf-8") for x in [u"\xea", u"ur",
                                                   u"\xcd\ua09ct"]],
                      [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea",
                                                   u"\ua09ct"]]],
    }[unit]
    self._checkSubstr(test_string, position, length, unit, expected_value)

    # Test 1D broadcast
    test_string = {
        "BYTE": b"thirteen",
        "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"),
    }[unit]
    position = np.array([1, -4, 7], dtype)
    length = np.array([3, 2, 1], dtype)
    expected_value = {
        "BYTE": [b"hir", b"te", b"n"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]],
    }[unit]
    self._checkSubstr(test_string, position, length, unit, expected_value)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  @test_util.run_deprecated_v1
  def testBadBroadcast(self, dtype, unit):
    """Expects ValueError when 4-element pos/len meet a 3x3 input."""
    test_string = [[b"ten", b"eleven", b"twelve"],
                   [b"thirteen", b"fourteen", b"fifteen"],
                   [b"sixteen", b"seventeen", b"eighteen"]]
    position = np.array([1, 2, -3, 4], dtype)
    length = np.array([1, 2, 3, 4], dtype)
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length, unit=unit)

  @parameterized.parameters(
      (np.int32, 6, "BYTE"),
      (np.int64, 6, "BYTE"),
      (np.int32, -6, "BYTE"),
      (np.int64, -6, "BYTE"),
      (np.int32, 6, "UTF8_CHAR"),
      (np.int64, 6, "UTF8_CHAR"),
      (np.int32, -6, "UTF8_CHAR"),
      (np.int64, -6, "UTF8_CHAR"),
  )
  @test_util.run_deprecated_v1
  def testOutOfRangeError_Scalar(self, dtype, pos, unit):
    """Expects InvalidArgumentError for pos = +/-6 on a five-unit string."""
    # Scalar/Scalar
    test_string = {
        "BYTE": b"Hello",
        "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
    }[unit]
    position = np.array(pos, dtype)
    length = np.array(3, dtype)
    self._checkSubstrOutOfRange(test_string, position, length, unit)

  @parameterized.parameters(
      (np.int32, 4, "BYTE"),
      (np.int64, 4, "BYTE"),
      (np.int32, -4, "BYTE"),
      (np.int64, -4, "BYTE"),
      (np.int32, 4, "UTF8_CHAR"),
      (np.int64, 4, "UTF8_CHAR"),
      (np.int32, -4, "UTF8_CHAR"),
      (np.int64, -4, "UTF8_CHAR"),
  )
  @test_util.run_deprecated_v1
  def testOutOfRangeError_VectorScalar(self, dtype, pos, unit):
    """Expects InvalidArgumentError for a vector input with scalar pos/len."""
    # Vector/Scalar
    test_string = {
        "BYTE": [b"good", b"good", b"bad", b"good"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d",
                                                  u"g\xc3\xc3d"]],
    }[unit]
    position = np.array(pos, dtype)
    length = np.array(1, dtype)
    self._checkSubstrOutOfRange(test_string, position, length, unit)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  @test_util.run_deprecated_v1
  def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
    """Expects InvalidArgumentError for matrix input with matrix pos/len."""
    # Matrix/Matrix
    test_string = {
        "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
                 [b"good", b"good", b"good"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"g\xc3\xc3d"]],
                      [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"b\xc3d"]],
                      [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"g\xc3\xc3d"]]],
    }[unit]
    position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
    length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
    self._checkSubstrOutOfRange(test_string, position, length, unit)

    # Matrix/Matrix (with negative)
    position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
    length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
    self._checkSubstrOutOfRange(test_string, position, length, unit)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  @test_util.run_deprecated_v1
  def testOutOfRangeError_Broadcast(self, dtype, unit):
    """Expects InvalidArgumentError when broadcast pos/len go out of range."""
    # Broadcast
    test_string = {
        "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"g\xc3\xc3d"]],
                      [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"b\xc3d"]]],
    }[unit]
    position = np.array([1, 2, 4], dtype)
    length = np.array([1, 2, 3], dtype)
    self._checkSubstrOutOfRange(test_string, position, length, unit)

    # Broadcast (with negative)
    position = np.array([-1, -2, -4], dtype)
    length = np.array([1, 2, 3], dtype)
    self._checkSubstrOutOfRange(test_string, position, length, unit)

  @parameterized.parameters(
      (np.int32, "BYTE"),
      (np.int64, "BYTE"),
      (np.int32, "UTF8_CHAR"),
      (np.int64, "UTF8_CHAR"),
  )
  @test_util.run_deprecated_v1
  def testMismatchPosLenShapes(self, dtype, unit):
    """Expects ValueError when pos and len do not have the same shape."""
    test_string = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
                                                   u"f\U0001f604urt\xea\xean",
                                                   u"f\xcd\ua09ctee\ua0e4"]],
                      [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
                                                   u"se\U00010299enteen",
                                                   u"ei\U0001e920h\x86een"]]],
    }[unit]
    position = np.array([[1, 2, 3]], dtype)
    length = np.array([2, 3, 4], dtype)
    # Should fail: position/length have different rank
    # `unit` is forwarded so the UTF8_CHAR parameterizations really exercise
    # that unit instead of silently using the op's default.
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length, unit=unit)

    position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
    length = np.array([[2, 3, 4]], dtype)
    # Should fail: position/length have different dimensionality
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length, unit=unit)

  @test_util.run_deprecated_v1
  def testWrongDtype(self):
    """Expects TypeError when pos or len is given as a float."""
    with self.cached_session():
      with self.assertRaises(TypeError):
        string_ops.substr(b"test", 3.0, 1)
      with self.assertRaises(TypeError):
        string_ops.substr(b"test", 3, 1.0)

  @test_util.run_deprecated_v1
  def testInvalidUnit(self):
    """Expects ValueError for the unit name "UTF8"."""
    with self.cached_session():
      with self.assertRaises(ValueError):
        string_ops.substr(b"test", 3, 1, unit="UTF8")
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
  test.main()