Update test_util.TensorFlowTestCase's assertAllEqual() and assertAllClose() methods to support RaggedTensors.

PiperOrigin-RevId: 254878390
This commit is contained in:
Edward Loper 2019-06-24 18:16:15 -07:00 committed by TensorFlower Gardener
parent b6b7d99893
commit b972f7334e
53 changed files with 741 additions and 873 deletions

View File

@ -1,266 +1,266 @@
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/mpi/BUILD
tensorflow/stream_executor/build_defs.bzl
tensorflow/python/autograph/core/config.py
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/astor.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/BUILD
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/gast.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/swig.BUILD
tensorflow/compat_template.__init__.py
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/README.md
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/api_template.__init__.py
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/mpi/BUILD
tensorflow/python/autograph/core/config.py
tensorflow/virtual_root_template_v2.__init__.py
tensorflow/__init__.py
tensorflow/stream_executor/build_defs.bzl
tensorflow/third_party/astor.BUILD
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/api_template_v1.__init__.py
tensorflow/compat_template_v1.__init__.py
tensorflow/compat_template.__init__.py
tensorflow/api_template.__init__.py
tensorflow/__init__.py
tensorflow/virtual_root_template_v2.__init__.py
tensorflow/virtual_root_template_v1.__init__.py

View File

@ -276,9 +276,12 @@ tf_py_test(
additional_deps = [
":test_base",
"//third_party/py/numpy",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/ops:dataset_ops",
@ -709,12 +712,12 @@ py_library(
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/eager:context",
"//tensorflow/python/ops/ragged",
"//tensorflow/python/ops/ragged:ragged_test_util",
],
)

View File

@ -258,7 +258,7 @@ class FromTensorSlicesTest(test_base.DatasetTestBase):
if sparse_tensor.is_sparse(component):
self.assertSparseValuesEqual(component, result_component)
elif ragged_tensor.is_ragged(component):
self.assertRaggedEqual(component, result_component)
self.assertAllEqual(component, result_component)
else:
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):

View File

@ -30,11 +30,10 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import test
class DatasetTestBase(ragged_test_util.RaggedTensorTestCase, test.TestCase):
class DatasetTestBase(test.TestCase):
"""Base class for dataset tests."""
@classmethod
@ -108,7 +107,7 @@ class DatasetTestBase(ragged_test_util.RaggedTensorTestCase, test.TestCase):
if sparse_tensor.is_sparse(result_value):
self.assertSparseValuesEqual(result_value, expected_value)
elif ragged_tensor.is_ragged(result_value):
self.assertRaggedEqual(result_value, expected_value)
self.assertAllEqual(result_value, expected_value)
else:
self.assertAllEqual(
result_value,
@ -209,7 +208,7 @@ class DatasetTestBase(ragged_test_util.RaggedTensorTestCase, test.TestCase):
if sparse_tensor.is_sparse(op1[i]):
self.assertSparseValuesEqual(op1[i], op2[i])
elif ragged_tensor.is_ragged(op1[i]):
self.assertRaggedEqual(op1[i], op2[i])
self.assertAllEqual(op1[i], op2[i])
elif flattened_types[i] == dtypes.string:
self.assertAllEqual(op1[i], op2[i])
else:

View File

@ -96,14 +96,21 @@ py_test(
":structure",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:type_spec",
"//tensorflow/python:variables",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/ops/ragged",
"//tensorflow/python/ops/ragged:ragged_test_util",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_tensor_value",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -37,7 +37,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import test
@ -47,7 +46,7 @@ from tensorflow.python.platform import test
#
# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure.
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
test_util.TensorFlowTestCase):
# pylint: disable=g-long-lambda,protected-access
@parameterized.named_parameters(
@ -348,7 +347,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
elif isinstance(
b,
(ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
self.assertRaggedEqual(b, a)
self.assertAllEqual(b, a)
else:
self.assertAllEqual(b, a)
@ -707,7 +706,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
if sparse_tensor.is_sparse(expected):
self.assertSparseValuesEqual(expected, actual)
elif ragged_tensor.is_ragged(expected):
self.assertRaggedEqual(expected, actual)
self.assertAllEqual(expected, actual)
else:
self.assertAllEqual(expected, actual)

View File

@ -1900,7 +1900,8 @@ class TensorFlowTestCase(googletest.TestCase):
tensor.dense_shape.numpy())
elif ragged_tensor.is_ragged(tensor):
return ragged_tensor_value.RaggedTensorValue(
tensor.values.numpy(), tensor.row_splits.numpy())
self._eval_tensor(tensor.values),
self._eval_tensor(tensor.row_splits))
elif isinstance(tensor, ops.IndexedSlices):
return ops.IndexedSlicesValue(
values=tensor.values.numpy(),
@ -2363,6 +2364,8 @@ class TensorFlowTestCase(googletest.TestCase):
to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
`[p] = [1]['d']`, then `a[p] = (6, 7)`.
"""
if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
return self._assertRaggedClose(a, b, rtol, atol, msg)
self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
@py_func_if_in_function
@ -2441,6 +2444,8 @@ class TensorFlowTestCase(googletest.TestCase):
b: the actual numpy ndarray or anything can be converted to one.
msg: Optional message to report on failure.
"""
if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
return self._assertRaggedEqual(a, b, msg)
msg = msg if msg else ""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
@ -2730,6 +2735,51 @@ class TensorFlowTestCase(googletest.TestCase):
device1, device2,
"Devices %s and %s are not equal. %s" % (device1, device2, msg))
def _GetPyList(self, a):
"""Converts `a` to a nested python list."""
if isinstance(a, ragged_tensor.RaggedTensor):
return self.evaluate(a).to_list()
elif isinstance(a, ops.Tensor):
a = self.evaluate(a)
return a.tolist() if isinstance(a, np.ndarray) else a
elif isinstance(a, np.ndarray):
return a.tolist()
elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
return a.to_list()
else:
return np.array(a).tolist()
def _assertRaggedEqual(self, a, b, msg):
"""Asserts that two ragged tensors are equal."""
a_list = self._GetPyList(a)
b_list = self._GetPyList(b)
self.assertEqual(a_list, b_list, msg)
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
a_list = self._GetPyList(a)
b_list = self._GetPyList(b)
self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
self.assertEqual(type(a), type(b))
if isinstance(a, (list, tuple)):
self.assertLen(a, len(b), "Length differs for %s" % path)
for i in range(len(a)):
self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
"%s[%s]" % (path, i))
else:
self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
# Fix Python 3 compatibility issues
if six.PY3:
# pylint: disable=invalid-name

View File

@ -1158,7 +1158,6 @@ tf_py_test(
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@ -1679,11 +1678,13 @@ tf_py_test(
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:ops",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:context",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_test_util",
],
)

View File

@ -40,7 +40,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import test
@ -155,8 +154,7 @@ def get_model_from_layers_with_input(layers,
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class CompositeTensorInternalTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
class CompositeTensorInternalTest(keras_parameterized.TestCase):
def test_internal_ragged_tensors(self):
# Create a model that accepts an input, converts it to Ragged, and
@ -207,8 +205,7 @@ class CompositeTensorInternalTest(keras_parameterized.TestCase,
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class CompositeTensorOutputTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
class CompositeTensorOutputTest(keras_parameterized.TestCase):
def test_ragged_tensor_outputs(self):
# Create a model that accepts an input, converts it to Ragged, and
@ -221,7 +218,7 @@ class CompositeTensorOutputTest(keras_parameterized.TestCase,
output = model.predict(input_data)
expected_values = [[1], [2, 3]]
self.assertRaggedEqual(expected_values, output)
self.assertAllEqual(expected_values, output)
def test_ragged_tensor_rebatched_outputs(self):
# Create a model that accepts an input, converts it to Ragged, and
@ -234,7 +231,7 @@ class CompositeTensorOutputTest(keras_parameterized.TestCase,
output = model.predict(input_data, batch_size=2)
expected_values = [[1], [2, 3], [4], [5, 6]]
self.assertRaggedEqual(expected_values, output)
self.assertAllEqual(expected_values, output)
def test_sparse_tensor_outputs(self):
# Create a model that accepts an input, converts it to Ragged, and
@ -315,8 +312,7 @@ def prepare_inputs(data, use_dict, use_dataset, action, input_name):
use_dict=[True, False],
use_dataset=[True, False],
action=["predict", "evaluate", "fit"]))
class SparseTensorInputTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
class SparseTensorInputTest(keras_parameterized.TestCase):
def test_sparse_tensors(self, use_dict, use_dataset, action):
data = [(sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
@ -357,7 +353,7 @@ class SparseTensorInputTest(keras_parameterized.TestCase,
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class ScipySparseTensorInputTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
test_util.TensorFlowTestCase):
def test_sparse_scipy_predict_inputs_via_input_layer_args(self):
# Create a model that accepts a sparse input and converts the sparse tensor
@ -473,7 +469,7 @@ class ScipySparseTensorInputTest(keras_parameterized.TestCase,
use_dataset=[True, False],
action=["predict", "evaluate", "fit"]))
class RaggedTensorInputTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
test_util.TensorFlowTestCase):
def test_ragged_input(self, use_dict, use_dataset, action):
data = [(ragged_factory_ops.constant([[[1]], [[2, 3]]]),
@ -510,7 +506,7 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
*test_util.generate_combinations_with_testcase_name(
use_dict=[True, False], use_dataset=[True, False]))
class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
test_util.TensorFlowTestCase):
def test_ragged_tensor_input_with_one_none_dimension(self, use_dict,
use_dataset):
@ -596,8 +592,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
# subclassed models, so we run a separate parameterized test for them.
@keras_parameterized.run_with_all_model_types(exclude_models=["subclass"])
@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
class SparseTensorInputValidationTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
class SparseTensorInputValidationTest(keras_parameterized.TestCase):
def test_sparse_scipy_input_checks_shape(self):
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int32)
@ -647,8 +642,7 @@ class SparseTensorInputValidationTest(keras_parameterized.TestCase,
@keras_parameterized.run_with_all_model_types(
exclude_models=["functional"])
@keras_parameterized.run_all_keras_modes
class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase,
ragged_test_util.RaggedTensorTestCase):
class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase):
def test_subclass_implicit_sparse_inputs_fails(self):
# Create a model that accepts a sparse input and converts the sparse tensor

View File

@ -25,13 +25,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSizeOpTest(ragged_test_util.RaggedTensorTestCase,
parameterized.TestCase):
class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.parameters([
{

View File

@ -1077,14 +1077,17 @@ tf_py_test(
size = "small",
srcs = ["string_split_op_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:string_ops",
"//tensorflow/python/ops/ragged:ragged_string_ops",
"//tensorflow/python/ops/ragged:ragged_test_util",
"//tensorflow/python:util",
],
)
@ -1093,14 +1096,17 @@ tf_py_test(
size = "small",
srcs = ["string_bytes_split_op_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_string_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:string_ops",
"//tensorflow/python/ops/ragged",
"//tensorflow/python/ops/ragged:ragged_test_util",
],
)
@ -1327,11 +1333,14 @@ tf_py_test(
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_tensor_value",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_string_ops",
"//tensorflow/python:framework_test_lib",
],
)
@ -1353,13 +1362,14 @@ tf_py_test(
srcs = ["unicode_decode_op_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//tensorflow/python/eager:context",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_string_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/ops/ragged:ragged",
"//tensorflow/python/ops/ragged:ragged_test_util",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",

View File

@ -22,13 +22,13 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import test
class StringsToBytesOpTest(ragged_test_util.RaggedTensorTestCase,
class StringsToBytesOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters(
@ -62,7 +62,7 @@ class StringsToBytesOpTest(ragged_test_util.RaggedTensorTestCase,
def testStringToBytes(self, source, expected):
expected = ragged_factory_ops.constant_value(expected, dtype=object)
result = ragged_string_ops.string_bytes_split(source)
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
if __name__ == '__main__':

View File

@ -29,7 +29,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@ -171,8 +170,7 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(shape, [3, 1])
class StringSplitV2OpTest(ragged_test_util.RaggedTensorTestCase,
parameterized.TestCase):
class StringSplitV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters([
{"testcase_name": "Simple",
@ -278,11 +276,11 @@ class StringSplitV2OpTest(ragged_test_util.RaggedTensorTestCase,
actual_ragged_v2 = ragged_string_ops.string_split_v2(input, **kwargs)
actual_ragged_v2_input_kwarg = ragged_string_ops.string_split_v2(
input=input, **kwargs)
self.assertRaggedEqual(expected_ragged, actual_ragged_v1)
self.assertRaggedEqual(expected_ragged, actual_ragged_v1_input_kwarg)
self.assertRaggedEqual(expected_ragged, actual_ragged_v1_source_kwarg)
self.assertRaggedEqual(expected_ragged, actual_ragged_v2)
self.assertRaggedEqual(expected_ragged, actual_ragged_v2_input_kwarg)
self.assertAllEqual(expected_ragged, actual_ragged_v1)
self.assertAllEqual(expected_ragged, actual_ragged_v1_input_kwarg)
self.assertAllEqual(expected_ragged, actual_ragged_v1_source_kwarg)
self.assertAllEqual(expected_ragged, actual_ragged_v2)
self.assertAllEqual(expected_ragged, actual_ragged_v2_input_kwarg)
# Check that the internal version (which returns a SparseTensor) works
# correctly. Note: the internal version oly supports vector inputs.

View File

@ -31,7 +31,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import test
@ -91,7 +90,7 @@ def _make_sparse_tensor(indices, values, dense_shape, dtype=np.int32):
@test_util.run_all_in_graph_and_eager_modes
class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
class UnicodeDecodeTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def testScalarDecode(self):
@ -110,15 +109,15 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
chars = ragged_string_ops.unicode_decode(text, "utf-8")
expected_chars = [[ord(c) for c in u"仅今年前"],
[ord(c) for c in u"hello"]]
self.assertRaggedEqual(chars, expected_chars)
self.assertAllEqual(chars, expected_chars)
def testVectorDecodeWithOffset(self):
text = constant_op.constant([u"仅今年前".encode("utf-8"), b"hello"])
chars, starts = ragged_string_ops.unicode_decode_with_offsets(text, "utf-8")
expected_chars = [[ord(c) for c in u"仅今年前"],
[ord(c) for c in u"hello"]]
self.assertRaggedEqual(chars, expected_chars)
self.assertRaggedEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
self.assertAllEqual(chars, expected_chars)
self.assertAllEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
@parameterized.parameters([
{"texts": u"仅今年前"},
@ -134,7 +133,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
_nested_encode(texts, "UTF-8"), ragged_rank=ragged_rank, dtype=bytes)
result = ragged_string_ops.unicode_decode(input_tensor, "UTF-8")
expected = _nested_codepoints(texts)
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
@parameterized.parameters([
{"texts": u"仅今年前"},
@ -152,19 +151,19 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
input_tensor, "UTF-8")
expected_codepoints = _nested_codepoints(texts)
expected_offsets = _nested_offsets(texts, "UTF-8")
self.assertRaggedEqual(expected_codepoints, result[0])
self.assertRaggedEqual(expected_offsets, result[1])
self.assertAllEqual(expected_codepoints, result[0])
self.assertAllEqual(expected_offsets, result[1])
def testDocstringExamples(self):
texts = [s.encode("utf8") for s in [u"G\xf6\xf6dnight", u"\U0001f60a"]]
codepoints1 = ragged_string_ops.unicode_decode(texts, "UTF-8")
codepoints2, offsets = ragged_string_ops.unicode_decode_with_offsets(
texts, "UTF-8")
self.assertRaggedEqual(
self.assertAllEqual(
codepoints1, [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
self.assertRaggedEqual(
self.assertAllEqual(
codepoints2, [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
self.assertRaggedEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
self.assertAllEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
@parameterized.parameters([
dict(
@ -263,7 +262,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
]) # pyformat: disable
def testErrorModes(self, expected=None, **args):
result = ragged_string_ops.unicode_decode(**args)
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
@parameterized.parameters([
dict(
@ -314,8 +313,8 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
expected_offsets=None,
**args):
result = ragged_string_ops.unicode_decode_with_offsets(**args)
self.assertRaggedEqual(result[0], expected)
self.assertRaggedEqual(result[1], expected_offsets)
self.assertAllEqual(result[0], expected)
self.assertAllEqual(result[1], expected_offsets)
@parameterized.parameters(
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
@ -329,7 +328,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
expected = _nested_codepoints(texts)
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
result = ragged_string_ops.unicode_decode(input_tensor, encoding)
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
@parameterized.parameters(
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
@ -345,8 +344,8 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
result = ragged_string_ops.unicode_decode_with_offsets(
input_tensor, encoding)
self.assertRaggedEqual(expected_codepoints, result[0])
self.assertRaggedEqual(expected_offsets, result[1])
self.assertAllEqual(expected_codepoints, result[0])
self.assertAllEqual(expected_offsets, result[1])
@parameterized.parameters([
dict(input=[b"\xFEED"],
@ -423,7 +422,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
@test_util.run_all_in_graph_and_eager_modes
class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
class UnicodeSplitTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def testScalarSplit(self):
@ -442,15 +441,15 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
chars = ragged_string_ops.unicode_split(text, "UTF-8")
expected_chars = [[c.encode("UTF-8") for c in u"仅今年前"],
[c.encode("UTF-8") for c in u"hello"]]
self.assertRaggedEqual(chars, expected_chars)
self.assertAllEqual(chars, expected_chars)
def testVectorSplitWithOffset(self):
text = constant_op.constant([u"仅今年前".encode("UTF-8"), b"hello"])
chars, starts = ragged_string_ops.unicode_split_with_offsets(text, "UTF-8")
expected_chars = [[c.encode("UTF-8") for c in u"仅今年前"],
[c.encode("UTF-8") for c in u"hello"]]
self.assertRaggedEqual(chars, expected_chars)
self.assertRaggedEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
self.assertAllEqual(chars, expected_chars)
self.assertAllEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
@parameterized.parameters([
{"texts": u"仅今年前"},
@ -466,7 +465,7 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
_nested_encode(texts, "UTF-8"), ragged_rank=ragged_rank, dtype=bytes)
result = ragged_string_ops.unicode_split(input_tensor, "UTF-8")
expected = _nested_splitchars(texts, "UTF-8")
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
@parameterized.parameters([
{"texts": u"仅今年前"},
@ -483,23 +482,23 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
result = ragged_string_ops.unicode_split_with_offsets(input_tensor, "UTF-8")
expected_codepoints = _nested_splitchars(texts, "UTF-8")
expected_offsets = _nested_offsets(texts, "UTF-8")
self.assertRaggedEqual(expected_codepoints, result[0])
self.assertRaggedEqual(expected_offsets, result[1])
self.assertAllEqual(expected_codepoints, result[0])
self.assertAllEqual(expected_offsets, result[1])
def testDocstringExamples(self):
texts = [s.encode("utf8") for s in [u"G\xf6\xf6dnight", u"\U0001f60a"]]
codepoints1 = ragged_string_ops.unicode_split(texts, "UTF-8")
codepoints2, offsets = ragged_string_ops.unicode_split_with_offsets(
texts, "UTF-8")
self.assertRaggedEqual(
self.assertAllEqual(
codepoints1,
[[b"G", b"\xc3\xb6", b"\xc3\xb6", b"d", b"n", b"i", b"g", b"h", b"t"],
[b"\xf0\x9f\x98\x8a"]])
self.assertRaggedEqual(
self.assertAllEqual(
codepoints2,
[[b"G", b"\xc3\xb6", b"\xc3\xb6", b"d", b"n", b"i", b"g", b"h", b"t"],
[b"\xf0\x9f\x98\x8a"]])
self.assertRaggedEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
self.assertAllEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
@parameterized.parameters([
dict(
@ -604,7 +603,7 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
]) # pyformat: disable
def testErrorModes(self, expected=None, **args):
result = ragged_string_ops.unicode_split(**args)
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
@parameterized.parameters([
dict(
@ -644,8 +643,8 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
expected_offsets=None,
**args):
result = ragged_string_ops.unicode_split_with_offsets(**args)
self.assertRaggedEqual(expected, result[0])
self.assertRaggedEqual(expected_offsets, result[1])
self.assertAllEqual(expected, result[0])
self.assertAllEqual(expected_offsets, result[1])
@parameterized.parameters(
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
@ -656,7 +655,7 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
expected = _nested_splitchars(texts, encoding)
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
result = ragged_string_ops.unicode_split(input_tensor, encoding)
self.assertRaggedEqual(expected, result)
self.assertAllEqual(expected, result)
@parameterized.parameters(
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
@ -669,8 +668,8 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
result = ragged_string_ops.unicode_split_with_offsets(
input_tensor, encoding)
self.assertRaggedEqual(expected_codepoints, result[0])
self.assertRaggedEqual(expected_offsets, result[1])
self.assertAllEqual(expected_codepoints, result[0])
self.assertAllEqual(expected_offsets, result[1])
@parameterized.parameters([
dict(input=[b"\xFEED"],

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
def assertRaggedEqual(self, rt, expected):
def assertAllEqual(self, rt, expected):
with self.cached_session() as sess:
value = sess.run(rt)
if isinstance(value, np.ndarray):
@ -76,10 +76,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
expected_value = u"Heo".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding,
"ignore")
with self.cached_session() as session:
result = session.run(unicode_encode_op)
self.assertIsInstance(result, bytes)
self.assertAllEqual(result, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -88,20 +85,20 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
expected_value = u"He\U0000fffd\U0000fffdo".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding,
"replace")
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
# Test custom replacement character
test_value = np.array([72, 101, 2147483647, -1, 111], np.int32)
expected_value = u"Heooo".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding,
"replace", 111)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
# Verify "replace" is default
test_value = np.array([72, 101, 2147483647, -1, 111], np.int32)
expected_value = u"He\U0000fffd\U0000fffdo".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
# Replacement_char must be within range
test_value = np.array([72, 101, 2147483647, -1, 111], np.int32)
@ -118,23 +115,23 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
test_value = np.array([72, 101, 108, 108, 111], np.int32)
expected_value = u"Hello".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
test_value = np.array([72, 101, 195, 195, 128516], np.int32)
expected_value = u"He\xc3\xc3\U0001f604".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
# Single character string
test_value = np.array([72], np.int32)
expected_value = u"H".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
test_value = np.array([128516], np.int32)
expected_value = u"\U0001f604".encode(encoding)
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -158,7 +155,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
[u"fixed".encode(encoding), u"words".encode(encoding)],
[u"Hyper".encode(encoding), u"cube.".encode(encoding)]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -174,7 +171,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
[[u"Hyper".encode(encoding)],
[u"cube.".encode(encoding)]]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
# -- Ragged Tensor tests -- #
@ -187,7 +184,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
u"H\xc3llo".encode(encoding), u"W\U0001f604rld.".encode(encoding)
]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -204,7 +201,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
u"cube.".encode(encoding)
]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -219,7 +216,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
u"w\xc3rry, be".encode(encoding)
], [u"\U0001f604".encode(encoding), u"".encode(encoding)]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -230,7 +227,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
expected_value = [[u"Hello".encode(encoding), u"World.".encode(encoding)],
[], [u"\U0001f604".encode(encoding)]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -241,7 +238,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
expected_value = [[[u"Hello".encode(encoding), u"World".encode(encoding)]],
[[u"".encode(encoding)], [u"Hype".encode(encoding)]]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
@test_util.run_v1_only("b/120545219")
@ -264,7 +261,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
[u"Hyper".encode(encoding),
u"cube.".encode(encoding)]]]]
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertRaggedEqual(unicode_encode_op, expected_value)
self.assertAllEqual(unicode_encode_op, expected_value)
if __name__ == "__main__":

View File

@ -53,21 +53,17 @@ py_library(
srcs = ["ragged_array_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_conversion_ops",
":ragged_functional_ops",
":ragged_math_ops",
":ragged_tensor",
":ragged_util",
":segment_id_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:ragged_array_ops_gen",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
],
)
@ -76,17 +72,11 @@ py_library(
srcs = ["ragged_batch_gather_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_concat_ops",
":ragged_conversion_ops",
":ragged_gather_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
":ragged_where_op",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
@ -100,8 +90,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
":ragged_concat_ops",
":ragged_dispatch",
":ragged_operators",
":ragged_tensor",
@ -109,7 +97,7 @@ py_library(
":ragged_where_op",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
@ -121,7 +109,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_conversion_ops",
":ragged_gather_ops",
":ragged_tensor",
":ragged_util",
@ -138,19 +125,8 @@ py_library(
srcs = ["ragged_conversion_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_tensor",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:ragged_conversion_ops_gen",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
],
)
@ -158,15 +134,14 @@ py_library(
name = "ragged_factory_ops",
srcs = ["ragged_factory_ops.py"],
deps = [
":ragged_tensor",
":ragged_tensor_value",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_tensor_value",
"//third_party/py/numpy",
],
)
@ -176,10 +151,12 @@ py_library(
srcs = ["ragged_functional_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_config",
":ragged_tensor",
":ragged_util",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
],
)
@ -190,7 +167,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_conversion_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@ -206,15 +182,14 @@ py_library(
srcs = ["ragged_getitem.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_gather_ops",
":ragged_math_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/eager:context",
],
)
@ -223,7 +198,6 @@ py_library(
srcs = ["ragged_math_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_functional_ops",
":ragged_tensor",
":ragged_util",
@ -236,6 +210,7 @@ py_library(
"//tensorflow/python:ragged_math_ops_gen",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
@ -257,14 +232,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_conversion_ops",
":ragged_factory_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
],
@ -276,10 +247,13 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
)
@ -287,9 +261,6 @@ py_library(
name = "ragged_config",
srcs = ["ragged_config.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dtypes",
],
)
py_library(
@ -301,11 +272,23 @@ py_library(
":ragged_tensor_value",
":ragged_util",
":segment_id_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:ragged_conversion_ops_gen",
"//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
@ -315,8 +298,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_conversion_ops",
":ragged_factory_ops",
":ragged_config",
":ragged_tensor",
":ragged_util",
"//tensorflow/python:array_ops",
@ -325,6 +307,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
],
)
@ -363,7 +346,6 @@ py_library(
":ragged_gather_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
@ -374,7 +356,6 @@ py_library(
srcs = ["segment_id_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_config",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@ -391,6 +372,7 @@ py_library(
srcs = ["ragged_map_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_config",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
@ -415,14 +397,18 @@ py_library(
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
":ragged_concat_ops",
":ragged_gather_ops",
":ragged_math_ops",
":ragged_squeeze_op",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
":ragged_where_op",
"//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
@ -438,19 +424,6 @@ py_library(
# RaggedTensor Tests
#-------------------------------------------------------------------------------
py_library(
name = "ragged_test_util",
srcs = ["ragged_test_util.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_tensor",
":ragged_tensor_value",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
py_test(
name = "ragged_tensor_test",
size = "medium",
@ -467,7 +440,6 @@ py_test(
":ragged_math_ops",
":ragged_tensor",
":ragged_tensor_value",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -475,6 +447,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@ -489,8 +462,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
],
@ -503,7 +476,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged_math_ops",
":ragged_test_util",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
@ -518,7 +490,6 @@ py_test(
deps = [
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
@ -532,7 +503,6 @@ py_test(
deps = [
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
@ -546,10 +516,8 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_gather_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -567,13 +535,12 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
":ragged_batch_gather_with_default_op",
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
@ -591,7 +558,6 @@ py_test(
deps = [
":ragged_factory_ops",
":ragged_gather_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@ -609,7 +575,6 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_test_util",
":segment_id_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
@ -623,7 +588,6 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_test_util",
":segment_id_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
@ -637,15 +601,11 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_string_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
],
@ -664,7 +624,6 @@ py_test(
":ragged_factory_ops",
":ragged_functional_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@ -683,7 +642,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@ -701,11 +659,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
@ -719,7 +676,6 @@ py_test(
":ragged_factory_ops",
":ragged_math_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
@ -737,7 +693,6 @@ py_test(
deps = [
":ragged_factory_ops",
":ragged_math_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -758,7 +713,6 @@ py_test(
":ragged_factory_ops",
":ragged_functional_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -778,10 +732,10 @@ py_test(
":ragged",
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
@ -796,9 +750,7 @@ py_test(
],
deps = [
":ragged_factory_ops",
":ragged_tensor",
":ragged_tensor_value",
":ragged_test_util",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//third_party/py/numpy",
@ -814,7 +766,6 @@ py_test(
deps = [
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
@ -832,7 +783,6 @@ py_test(
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@ -851,7 +801,6 @@ py_test(
deps = [
":ragged_concat_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -871,7 +820,6 @@ py_test(
deps = [
":ragged_concat_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
@ -887,8 +835,8 @@ py_test(
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
],
)
@ -901,7 +849,6 @@ py_test(
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -917,7 +864,6 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_test_util",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
@ -936,7 +882,6 @@ py_test(
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
@ -950,7 +895,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_test_util",
":ragged_where_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
@ -967,8 +911,8 @@ py_test(
":ragged", # fixdeps: keep
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@ -993,7 +937,6 @@ py_test(
deps = [
":ragged", # fixdeps: keep
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
@ -1012,12 +955,12 @@ py_test(
":ragged_map_ops",
":ragged_math_ops",
":ragged_tensor",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python/keras:backend",
"//third_party/py/numpy",
@ -1035,7 +978,6 @@ py_test(
":ragged_factory_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_test_util",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
@ -1052,8 +994,6 @@ py_test(
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_test_util",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
@ -1066,11 +1006,12 @@ py_test(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ragged_math_ops",
":ragged_test_util",
"//tensorflow/python:errors",
":ragged_factory_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:context",
"@absl_py//absl/testing:parameterized",
],
)
@ -1083,8 +1024,9 @@ py_test(
":ragged_conversion_ops",
":ragged_factory_ops",
":ragged_squeeze_op",
":ragged_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",

View File

@ -26,13 +26,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedConvertToTensorOrRaggedTensorTest(
ragged_test_util.RaggedTensorTestCase, parameterized.TestCase):
class RaggedConvertToTensorOrRaggedTensorTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
#=============================================================================
# Tests where the 'value' param is a RaggedTensor
@ -126,7 +125,7 @@ class RaggedConvertToTensorOrRaggedTensorTest(
value, dtype, preferred_dtype)
self.assertEqual(value.ragged_rank, converted.ragged_rank)
self.assertEqual(dtypes.as_dtype(expected_dtype), converted.dtype)
self.assertEqual(value.to_list(), self.eval_to_list(converted))
self.assertAllEqual(value, converted)
@parameterized.parameters([
dict(

View File

@ -30,12 +30,11 @@ from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -149,7 +148,7 @@ class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
])
def testRaggedBatchGather(self, descr, params, indices, expected):
result = ragged_batch_gather_ops.batch_gather(params, indices)
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
@parameterized.parameters([
# Docstring example:
@ -359,7 +358,7 @@ class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
expected, ragged_rank=expected_ragged_rank or ragged_rank)
result = ragged_batch_gather_with_default_op.batch_gather_with_default(
params, indices, default_value)
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
@parameterized.parameters([
# Dimensions:

View File

@ -27,12 +27,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedBooleanMaskOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedBooleanMaskOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
# Define short constants for true & false, so the data & mask can be lined
# up in the examples below. This makes it easier to read the examples, to
@ -243,7 +242,7 @@ class RaggedBooleanMaskOpTest(ragged_test_util.RaggedTensorTestCase,
]) # pyformat: disable
def testBooleanMask(self, descr, data, mask, expected):
actual = ragged_array_ops.boolean_mask(data, mask)
self.assertRaggedEqual(actual, expected)
self.assertAllEqual(actual, expected)
def testErrors(self):
if not context.executing_eagerly():

View File

@ -28,12 +28,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedConcatOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def _rt_inputs_to_tensors(self, rt_inputs, ragged_ranks=None):
@ -240,7 +239,7 @@ class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertEqual(concatenated.ragged_rank, expected_ragged_rank)
if expected_shape is not None:
self.assertEqual(concatenated.shape.as_list(), expected_shape)
self.assertRaggedEqual(concatenated, expected)
self.assertAllEqual(concatenated, expected)
@parameterized.parameters(
dict(
@ -318,7 +317,7 @@ class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
"""
rt_inputs = ragged_factory_ops.constant([[1, 2], [3, 4]])
concatenated = ragged_concat_ops.concat(rt_inputs, 0)
self.assertRaggedEqual(concatenated, [[1, 2], [3, 4]])
self.assertAllEqual(concatenated, [[1, 2], [3, 4]])
if __name__ == '__main__':

View File

@ -26,12 +26,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import ragged
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedConstOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters(
@ -203,7 +202,7 @@ class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
pylist, dtype=dtype, ragged_rank=ragged_rank, inner_shape=inner_shape)
# Normalize the pylist, i.e., convert all np.arrays to list.
# E.g., [np.array((1,2))] --> [[1,2]]
pylist = self._normalize_pylist(pylist)
pylist = _normalize_pylist(pylist)
# If dtype was explicitly specified, check it.
if dtype is not None:
@ -227,8 +226,11 @@ class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
if expected_shape is not None:
self.assertEqual(tuple(rt.shape.as_list()), expected_shape)
if (expected_shape and expected_shape[0] == 0 and
None not in expected_shape):
pylist = np.zeros(expected_shape, rt.dtype.as_numpy_dtype)
self.assertRaggedEqual(rt, pylist)
self.assertAllEqual(rt, pylist)
@parameterized.parameters(
dict(
@ -399,5 +401,14 @@ class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
pylist, ragged_rank), inner_shape)
def _normalize_pylist(item):
"""Convert all (possibly nested) np.arrays contained in item to list."""
# convert np.arrays in current level to list
if np.ndim(item) == 0:
return item
level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
return [_normalize_pylist(el) if np.ndim(el) != 0 else el for el in level]
if __name__ == '__main__':
googletest.main()

View File

@ -24,14 +24,12 @@ import numpy as np
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedConstantValueOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedConstantValueOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters(
#=========================================================================
# 0-dimensional tensors.
@ -190,7 +188,7 @@ class RaggedConstantValueOpTest(ragged_test_util.RaggedTensorTestCase,
pylist, dtype=dtype, ragged_rank=ragged_rank, inner_shape=inner_shape)
# Normalize the pylist, i.e., convert all np.arrays to list.
# E.g., [np.array((1,2))] --> [[1,2]]
pylist = self._normalize_pylist(pylist)
pylist = _normalize_pylist(pylist)
# If dtype was explicitly specified, check it.
if dtype is not None:
self.assertEqual(rt.dtype, dtype)
@ -315,5 +313,14 @@ class RaggedConstantValueOpTest(ragged_test_util.RaggedTensorTestCase,
inner_shape=inner_shape)
def _normalize_pylist(item):
"""Convert all (possibly nested) np.arrays contained in item to list."""
# convert np.arrays in current level to list
if np.ndim(item) == 0:
return item
level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
return [_normalize_pylist(el) if np.ndim(el) != 0 else el for el in level]
if __name__ == '__main__':
googletest.main()

View File

@ -35,7 +35,6 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
# Constants listing various op types to test. Each operation
@ -137,7 +136,7 @@ BINARY_INT_OPS = [
@test_util.run_all_in_graph_and_eager_modes
class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def assertSameShape(self, x, y):
@ -468,7 +467,7 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
result = x + y
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
def testElementwiseOpShapeMismatch(self):
x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
@ -713,7 +712,7 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
def testRaggedDispatch(self, op, expected, args=(), kwargs=None):
if kwargs is None: kwargs = {}
result = op(*args, **kwargs)
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
if __name__ == '__main__':

View File

@ -21,12 +21,12 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
class RaggedTensorTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -36,7 +36,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
])
def testRaggedTensorToList(self, pylist, ragged_rank=None):
rt = ragged_factory_ops.constant(pylist, ragged_rank)
self.assertRaggedEqual(rt, pylist)
self.assertAllEqual(rt, pylist)
@parameterized.parameters([
dict(pylist=[[b'a', b'b'], [b'c']]),

View File

@ -23,12 +23,11 @@ from absl.testing import parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedExpandDimsOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
# An example 4-d ragged tensor with shape [3, (D2), (D3), 2], and the
@ -120,7 +119,7 @@ class RaggedExpandDimsOpTest(ragged_test_util.RaggedTensorTestCase,
if expected_shape is not None:
self.assertEqual(expanded.shape.as_list(), expected_shape)
self.assertRaggedEqual(expanded, expected)
self.assertAllEqual(expanded, expected)
if __name__ == '__main__':

View File

@ -24,13 +24,12 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase):
def testDocStringExample(self):
st = sparse_tensor.SparseTensor(
@ -39,7 +38,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
dense_shape=[4, 3])
rt = RaggedTensor.from_sparse(st)
self.assertRaggedEqual(rt, [[1, 2, 3], [4], [], [5]])
self.assertAllEqual(rt, [[1, 2, 3], [4], [], [5]])
def testEmpty(self):
st = sparse_tensor.SparseTensor(
@ -48,7 +47,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
dense_shape=[4, 3])
rt = RaggedTensor.from_sparse(st)
self.assertRaggedEqual(rt, [[], [], [], []])
self.assertAllEqual(rt, [[], [], [], []])
def testBadSparseTensorRank(self):
st1 = sparse_tensor.SparseTensor(indices=[[0]], values=[0], dense_shape=[3])

View File

@ -24,31 +24,30 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def testDocStringExamples(self):
# The examples from RaggedTensor.from_tensor.__doc__.
dt = constant_op.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
self.assertRaggedEqual(
self.assertAllEqual(
RaggedTensor.from_tensor(dt), [[5, 7, 0], [0, 3, 0], [6, 0, 0]])
self.assertRaggedEqual(
self.assertAllEqual(
RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]), [[5], [], [6, 0, 0]])
self.assertRaggedEqual(
self.assertAllEqual(
RaggedTensor.from_tensor(dt, padding=0), [[5, 7], [0, 3], [6]])
dt_3d = constant_op.constant([[[5, 0], [7, 0], [0, 0]],
[[0, 0], [3, 0], [0, 0]],
[[6, 0], [0, 0], [0, 0]]])
self.assertRaggedEqual(
self.assertAllEqual(
RaggedTensor.from_tensor(dt_3d, lengths=([2, 0, 3], [1, 1, 2, 0, 1])),
[[[5], [7]], [], [[6, 0], [], [0]]])
@ -320,7 +319,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertTrue(
dt.shape.is_compatible_with(rt.shape),
'%s is incompatible with %s' % (dt.shape, rt.shape))
self.assertRaggedEqual(rt, expected)
self.assertAllEqual(rt, expected)
def testHighDimensions(self):
# Use distinct prime numbers for all dimension shapes in this test, so
@ -334,7 +333,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertTrue(
dt.shape.is_compatible_with(rt.shape),
'%s is incompatible with %s' % (dt.shape, rt.shape))
self.assertRaggedEqual(rt, self.evaluate(dt).tolist())
self.assertAllEqual(rt, self.evaluate(dt).tolist())
@parameterized.parameters(
# With no padding or lengths
@ -444,7 +443,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertEqual(type(rt), RaggedTensor)
self.assertEqual(rt.ragged_rank, 1)
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
self.assertRaggedEqual(rt, expected)
self.assertAllEqual(rt, expected)
@parameterized.parameters(
{

View File

@ -28,12 +28,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedGatherNdOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedGatherNdOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
DOCSTRING_PARAMS = [[['000', '001'], ['010']],
@ -202,7 +201,7 @@ class RaggedGatherNdOpTest(ragged_test_util.RaggedTensorTestCase,
]) # pyformat: disable
def testRaggedGatherNd(self, descr, params, indices, expected):
result = ragged_gather_ops.gather_nd(params, indices)
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
def testRaggedGatherNdUnknownRankError(self):
if context.executing_eagerly():

View File

@ -26,12 +26,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedGatherOpTest(test_util.TensorFlowTestCase):
def testDocStringExamples(self):
params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
@ -39,20 +38,20 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
ragged_params = ragged_factory_ops.constant([['a', 'b', 'c'], ['d'], [],
['e']])
ragged_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, ragged_indices),
[[b'd', b'b', b'c'], [b'b'], [], [b'a']])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(ragged_params, indices),
[[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(ragged_params, ragged_indices),
[[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
def testTensorParamsAndTensorIndices(self):
params = ['a', 'b', 'c', 'd', 'e']
indices = [2, 0, 2, 1]
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices), [b'c', b'a', b'c', b'b'])
self.assertIsInstance(ragged_gather_ops.gather(params, indices), ops.Tensor)
@ -60,14 +59,14 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
[], ['g']])
indices = [2, 0, 2, 1]
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices),
[[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']])
def testTensorParamsAndRaggedIndices(self):
params = ['a', 'b', 'c', 'd', 'e']
indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices),
[[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])
@ -75,7 +74,7 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
[], ['g']])
indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices),
[[[b'f'], [b'c', b'd', b'e']], # [[p[2], p[1] ],
[[b'c', b'd', b'e'], [b'f'], [b'a', b'b']], # [p[1], p[2], p[0]],
@ -86,14 +85,14 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
[], ['g']])
indices = 1
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices), [b'c', b'd', b'e'])
def test3DRaggedParamsAnd2DTensorIndices(self):
params = ragged_factory_ops.constant([[['a', 'b'], []],
[['c', 'd'], ['e'], ['f']], [['g']]])
indices = [[1, 2], [0, 1], [2, 2]]
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices),
[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
@ -107,7 +106,7 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
ragged_rank=2,
inner_shape=(2,))
params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
self.assertRaggedEqual(
self.assertAllEqual(
ragged_gather_ops.gather(params, indices),
[[[[b'd', b'e'], [b'a', b'g']], []],
[[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],

View File

@ -27,12 +27,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedMapInnerValuesOpTest(test_util.TensorFlowTestCase):
def assertRaggedMapInnerValuesReturns(self,
op,
@ -41,7 +40,7 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
kwargs=None):
kwargs = kwargs or {}
result = ragged_functional_ops.map_flat_values(op, *args, **kwargs)
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
def testDocStringExamples(self):
"""Test the examples in apply_op_to_ragged_values.__doc__."""
@ -49,9 +48,9 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
v1 = ragged_functional_ops.map_flat_values(array_ops.ones_like, rt)
v2 = ragged_functional_ops.map_flat_values(math_ops.multiply, rt, rt)
v3 = ragged_functional_ops.map_flat_values(math_ops.add, rt, 5)
self.assertRaggedEqual(v1, [[1, 1, 1], [], [1, 1], [1]])
self.assertRaggedEqual(v2, [[1, 4, 9], [], [16, 25], [36]])
self.assertRaggedEqual(v3, [[6, 7, 8], [], [9, 10], [11]])
self.assertAllEqual(v1, [[1, 1, 1], [], [1, 1], [1]])
self.assertAllEqual(v2, [[1, 4, 9], [], [16, 25], [36]])
self.assertAllEqual(v3, [[6, 7, 8], [], [9, 10], [11]])
def testOpWithSingleRaggedTensorArg(self):
tensor = ragged_factory_ops.constant([[1, 2, 3], [], [4, 5]])
@ -122,7 +121,7 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
# ragged_rank=0
x0 = [3, 1, 4, 1, 5, 9, 2, 6, 5]
y0 = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.assertRaggedEqual(
self.assertAllEqual(
math_ops.multiply(x0, y0), [3, 2, 12, 4, 25, 54, 14, 48, 45])
# ragged_rank=1

View File

@ -32,12 +32,11 @@ from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedMapOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -168,7 +167,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
expected_rt = ragged_factory_ops.constant(
expected_output, ragged_rank=expected_ragged_rank)
self.assertRaggedEqual(expected_rt, output)
self.assertAllEqual(expected_rt, output)
def testRaggedMapOnStructure(self):
batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
@ -186,7 +185,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
dtype=dtypes.int32,
)
self.assertRaggedEqual(output, [66, 44, 198])
self.assertAllEqual(output, [66, 44, 198])
# Test mapping over a dict of RTs can produce a dict of RTs.
def testRaggedMapOnStructure_RaggedOutputs(self):
@ -216,8 +215,8 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
},
)
self.assertRaggedEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]])
self.assertRaggedEqual(output['robin'], [[11, 21, 31], [41], [51, 61, 71]])
self.assertAllEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]])
self.assertAllEqual(output['robin'], [[11, 21, 31], [41], [51, 61, 71]])
def testZip(self):
x = ragged_factory_ops.constant(
@ -234,7 +233,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64, ragged_rank=1),
infer_shape=False)
self.assertRaggedEqual(
self.assertAllEqual(
output, [[[0, 10], [0, 20]], [[1, 30], [1, 40]], [[2, 50], [2, 60]],
[[3, 70]], [[4, 80], [4, 90], [4, 100]]])
@ -255,7 +254,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
dtype=dtypes.string, ragged_rank=1),
infer_shape=False)
self.assertRaggedEqual(
self.assertAllEqual(
out, [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']])
def testMismatchRaggedRank(self):
@ -290,7 +289,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
id_t2 = ragged_map_ops.map_fn(
lambda x: x, t2,
)
self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]])
self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
if __name__ == '__main__':

View File

@ -20,71 +20,70 @@ from __future__ import print_function
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase):
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase):
def testOrderingOperators(self):
x = ragged_factory_ops.constant([[1, 5], [3]])
y = ragged_factory_ops.constant([[4, 5], [1]])
self.assertRaggedEqual((x > y), [[False, False], [True]])
self.assertRaggedEqual((x >= y), [[False, True], [True]])
self.assertRaggedEqual((x < y), [[True, False], [False]])
self.assertRaggedEqual((x <= y), [[True, True], [False]])
self.assertAllEqual((x > y), [[False, False], [True]])
self.assertAllEqual((x >= y), [[False, True], [True]])
self.assertAllEqual((x < y), [[True, False], [False]])
self.assertAllEqual((x <= y), [[True, True], [False]])
def testArithmeticOperators(self):
x = ragged_factory_ops.constant([[1.0, -2.0], [8.0]])
y = ragged_factory_ops.constant([[4.0, 4.0], [2.0]])
self.assertRaggedEqual(abs(x), [[1.0, 2.0], [8.0]])
self.assertAllEqual(abs(x), [[1.0, 2.0], [8.0]])
self.assertRaggedEqual((-x), [[-1.0, 2.0], [-8.0]])
self.assertAllEqual((-x), [[-1.0, 2.0], [-8.0]])
self.assertRaggedEqual((x + y), [[5.0, 2.0], [10.0]])
self.assertRaggedEqual((3.0 + y), [[7.0, 7.0], [5.0]])
self.assertRaggedEqual((x + 3.0), [[4.0, 1.0], [11.0]])
self.assertAllEqual((x + y), [[5.0, 2.0], [10.0]])
self.assertAllEqual((3.0 + y), [[7.0, 7.0], [5.0]])
self.assertAllEqual((x + 3.0), [[4.0, 1.0], [11.0]])
self.assertRaggedEqual((x - y), [[-3.0, -6.0], [6.0]])
self.assertRaggedEqual((3.0 - y), [[-1.0, -1.0], [1.0]])
self.assertRaggedEqual((x + 3.0), [[4.0, 1.0], [11.0]])
self.assertAllEqual((x - y), [[-3.0, -6.0], [6.0]])
self.assertAllEqual((3.0 - y), [[-1.0, -1.0], [1.0]])
self.assertAllEqual((x + 3.0), [[4.0, 1.0], [11.0]])
self.assertRaggedEqual((x * y), [[4.0, -8.0], [16.0]])
self.assertRaggedEqual((3.0 * y), [[12.0, 12.0], [6.0]])
self.assertRaggedEqual((x * 3.0), [[3.0, -6.0], [24.0]])
self.assertAllEqual((x * y), [[4.0, -8.0], [16.0]])
self.assertAllEqual((3.0 * y), [[12.0, 12.0], [6.0]])
self.assertAllEqual((x * 3.0), [[3.0, -6.0], [24.0]])
self.assertRaggedEqual((x / y), [[0.25, -0.5], [4.0]])
self.assertRaggedEqual((y / x), [[4.0, -2.0], [0.25]])
self.assertRaggedEqual((2.0 / y), [[0.5, 0.5], [1.0]])
self.assertRaggedEqual((x / 2.0), [[0.5, -1.0], [4.0]])
self.assertAllEqual((x / y), [[0.25, -0.5], [4.0]])
self.assertAllEqual((y / x), [[4.0, -2.0], [0.25]])
self.assertAllEqual((2.0 / y), [[0.5, 0.5], [1.0]])
self.assertAllEqual((x / 2.0), [[0.5, -1.0], [4.0]])
self.assertRaggedEqual((x // y), [[0.0, -1.0], [4.0]])
self.assertRaggedEqual((y // x), [[4.0, -2.0], [0.0]])
self.assertRaggedEqual((2.0 // y), [[0.0, 0.0], [1.0]])
self.assertRaggedEqual((x // 2.0), [[0.0, -1.0], [4.0]])
self.assertAllEqual((x // y), [[0.0, -1.0], [4.0]])
self.assertAllEqual((y // x), [[4.0, -2.0], [0.0]])
self.assertAllEqual((2.0 // y), [[0.0, 0.0], [1.0]])
self.assertAllEqual((x // 2.0), [[0.0, -1.0], [4.0]])
self.assertRaggedEqual((x % y), [[1.0, 2.0], [0.0]])
self.assertRaggedEqual((y % x), [[0.0, -0.0], [2.0]])
self.assertRaggedEqual((2.0 % y), [[2.0, 2.0], [0.0]])
self.assertRaggedEqual((x % 2.0), [[1.0, 0.0], [0.0]])
self.assertAllEqual((x % y), [[1.0, 2.0], [0.0]])
self.assertAllEqual((y % x), [[0.0, -0.0], [2.0]])
self.assertAllEqual((2.0 % y), [[2.0, 2.0], [0.0]])
self.assertAllEqual((x % 2.0), [[1.0, 0.0], [0.0]])
def testLogicalOperators(self):
a = ragged_factory_ops.constant([[True, True], [False]])
b = ragged_factory_ops.constant([[True, False], [False]])
self.assertRaggedEqual((~a), [[False, False], [True]])
self.assertAllEqual((~a), [[False, False], [True]])
self.assertRaggedEqual((a & b), [[True, False], [False]])
self.assertRaggedEqual((a & True), [[True, True], [False]])
self.assertRaggedEqual((True & b), [[True, False], [False]])
self.assertAllEqual((a & b), [[True, False], [False]])
self.assertAllEqual((a & True), [[True, True], [False]])
self.assertAllEqual((True & b), [[True, False], [False]])
self.assertRaggedEqual((a | b), [[True, True], [False]])
self.assertRaggedEqual((a | False), [[True, True], [False]])
self.assertRaggedEqual((False | b), [[True, False], [False]])
self.assertAllEqual((a | b), [[True, True], [False]])
self.assertAllEqual((a | False), [[True, True], [False]])
self.assertAllEqual((False | b), [[True, False], [False]])
self.assertRaggedEqual((a ^ b), [[False, True], [False]])
self.assertRaggedEqual((a ^ True), [[False, False], [True]])
self.assertRaggedEqual((True ^ b), [[False, True], [True]])
self.assertAllEqual((a ^ b), [[False, True], [False]])
self.assertAllEqual((a ^ True), [[False, False], [True]])
self.assertAllEqual((True ^ b), [[False, True], [True]])
def testDummyOperators(self):
a = ragged_factory_ops.constant([[True, True], [False]])

View File

@ -23,12 +23,11 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedPlaceholderOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([

View File

@ -21,40 +21,39 @@ from __future__ import print_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedRangeOpTest(test_util.TensorFlowTestCase):
def testDocStringExamples(self):
"""Examples from ragged_range.__doc__."""
rt1 = ragged_math_ops.range([3, 5, 2])
self.assertRaggedEqual(rt1, [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]])
self.assertAllEqual(rt1, [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]])
rt2 = ragged_math_ops.range([0, 5, 8], [3, 3, 12])
self.assertRaggedEqual(rt2, [[0, 1, 2], [], [8, 9, 10, 11]])
self.assertAllEqual(rt2, [[0, 1, 2], [], [8, 9, 10, 11]])
rt3 = ragged_math_ops.range([0, 5, 8], [3, 3, 12], 2)
self.assertRaggedEqual(rt3, [[0, 2], [], [8, 10]])
self.assertAllEqual(rt3, [[0, 2], [], [8, 10]])
def testBasicRanges(self):
# Specify limits only.
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([0, 3, 5]),
[list(range(0)), list(range(3)),
list(range(5))])
# Specify starts and limits.
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([0, 3, 5], [2, 3, 10]),
[list(range(0, 2)),
list(range(3, 3)),
list(range(5, 10))])
# Specify starts, limits, and deltas.
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([0, 3, 5], [4, 4, 15], [2, 3, 4]),
[list(range(0, 4, 2)),
list(range(3, 4, 3)),
@ -65,18 +64,16 @@ class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
[5.0, 7.2, 9.4, 11.6, 13.8]]
actual = ragged_math_ops.range([0.0, 3.0, 5.0], [3.9, 4.0, 15.0],
[0.4, 1.5, 2.2])
self.assertEqual(
expected,
[[round(v, 5) for v in row] for row in self.eval_to_list(actual)])
self.assertAllClose(actual, expected)
def testNegativeDeltas(self):
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([0, 3, 5], limits=0, deltas=-1),
[list(range(0, 0, -1)),
list(range(3, 0, -1)),
list(range(5, 0, -1))])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([0, -3, 5], limits=0, deltas=[-1, 1, -2]),
[list(range(0, 0, -1)),
list(range(-3, 0, 1)),
@ -84,21 +81,21 @@ class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
def testBroadcast(self):
# Specify starts and limits, broadcast deltas.
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([0, 3, 5], [4, 4, 15], 3),
[list(range(0, 4, 3)),
list(range(3, 4, 3)),
list(range(5, 15, 3))])
# Broadcast all arguments.
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range(0, 5, 1), [list(range(0, 5, 1))])
def testEmptyRanges(self):
rt1 = ragged_math_ops.range([0, 5, 3], [0, 3, 5])
rt2 = ragged_math_ops.range([0, 5, 5], [0, 3, 5], -1)
self.assertRaggedEqual(rt1, [[], [], [3, 4]])
self.assertRaggedEqual(rt2, [[], [5, 4], []])
self.assertAllEqual(rt1, [[], [], [3, 4]])
self.assertAllEqual(rt2, [[], [5, 4], []])
def testShapeFnErrors(self):
self.assertRaises((ValueError, errors.InvalidArgumentError),
@ -116,11 +113,11 @@ class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
self.evaluate(ragged_math_ops.range(0, 0, 0))
def testShape(self):
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range(0, 0, 1).shape.as_list(), [1, None])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([1, 2, 3]).shape.as_list(), [3, None])
self.assertRaggedEqual(
self.assertAllEqual(
ragged_math_ops.range([1, 2, 3], [4, 5, 6]).shape.as_list(), [3, None])

View File

@ -21,12 +21,11 @@ from absl.testing import parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedRankOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedRankOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([

View File

@ -28,7 +28,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
_MAX_INT32 = dtypes.int32.max
@ -41,7 +40,7 @@ def mean(*values):
@test_util.run_all_in_graph_and_eager_modes
class RaggedReduceOpsTest(ragged_test_util.RaggedTensorTestCase,
class RaggedReduceOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters(
@ -320,7 +319,7 @@ class RaggedReduceOpsTest(ragged_test_util.RaggedTensorTestCase,
def testReduce(self, ragged_reduce_op, rt_input, axis, expected):
rt_input = ragged_factory_ops.constant(rt_input)
reduced = ragged_reduce_op(rt_input, axis)
self.assertRaggedEqual(reduced, expected)
self.assertAllEqual(reduced, expected)
def assertEqualWithNan(self, actual, expected):
"""Like assertEqual, but NaN==NaN."""
@ -340,7 +339,7 @@ class RaggedReduceOpsTest(ragged_test_util.RaggedTensorTestCase,
tensor = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
expected = [2.0, 20.0]
reduced = ragged_math_ops.reduce_mean(tensor, axis=1)
self.assertRaggedEqual(reduced, expected)
self.assertAllEqual(reduced, expected)
def testErrors(self):
rt_input = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])

View File

@ -24,12 +24,11 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedRowLengthsOp(ragged_test_util.RaggedTensorTestCase,
class RaggedRowLengthsOp(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -120,7 +119,7 @@ class RaggedRowLengthsOp(ragged_test_util.RaggedTensorTestCase,
expected_ragged_rank=None):
rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
lengths = rt.row_lengths(axis)
self.assertRaggedEqual(lengths, expected)
self.assertAllEqual(lengths, expected)
if expected_ragged_rank is not None:
if isinstance(lengths, ragged_tensor.RaggedTensor):
self.assertEqual(lengths.ragged_rank, expected_ragged_rank)

View File

@ -20,13 +20,12 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSplitsToSegmentIdsOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedSplitsToSegmentIdsOpTest(test_util.TensorFlowTestCase):
def testDocStringExample(self):
splits = [0, 3, 3, 5, 6, 9]

View File

@ -20,13 +20,12 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSplitsToSegmentIdsOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedSplitsToSegmentIdsOpTest(test_util.TensorFlowTestCase):
def testDocStringExample(self):
segment_ids = [0, 0, 0, 2, 2, 3, 4, 4, 4]

View File

@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@ -49,7 +48,7 @@ def sqrt_n(values):
@test_util.run_all_in_graph_and_eager_modes
class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
class RaggedSegmentOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def expected_value(self, data, segment_ids, num_segments, combiner):
@ -110,7 +109,7 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
combiner)
segmented = segment_op(rt, segment_ids, num_segments)
self.assertRaggedEqual(segmented, expected)
self.assertAllEqual(segmented, expected)
@parameterized.parameters(
(ragged_math_ops.segment_sum, sum, [0, 0, 1, 1, 2, 2]),
@ -146,7 +145,7 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
combiner)
segmented = segment_op(rt, segment_ids, num_segments)
self.assertRaggedAlmostEqual(segmented, expected, places=5)
self.assertAllClose(segmented, expected)
def testRaggedRankTwo(self):
rt = ragged_factory_ops.constant([
@ -161,14 +160,14 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
[], # row 1
[[411, 412], [321, 322], [331]] # row 2
] # pyformat: disable
self.assertRaggedEqual(segmented1, expected1)
self.assertAllEqual(segmented1, expected1)
segment_ids2 = [1, 2, 1, 1]
segmented2 = ragged_math_ops.segment_sum(rt, segment_ids2, 3)
expected2 = [[],
[[111+411, 112+412, 113, 114], [121+321, 322], [331]],
[]] # pyformat: disable
self.assertRaggedEqual(segmented2, expected2)
self.assertAllEqual(segmented2, expected2)
def testRaggedSegmentIds(self):
rt = ragged_factory_ops.constant([
@ -182,7 +181,7 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
expected = [[],
[111+321, 112+322, 113, 114],
[121+331+411, 412]] # pyformat: disable
self.assertRaggedEqual(segmented, expected)
self.assertAllEqual(segmented, expected)
def testShapeMismatchError1(self):
dt = constant_op.constant([1, 2, 3, 4, 5, 6])

View File

@ -23,12 +23,11 @@ from absl.testing import parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSizeOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedSizeOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([

View File

@ -27,12 +27,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
class RaggedSqueezeTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -52,7 +51,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
rt = ragged_squeeze_op.squeeze(
ragged_factory_ops.constant(input_list), squeeze_ranks)
dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)
self.assertAllEqual(ragged_conversion_ops.to_tensor(rt), dt)
@parameterized.parameters([
{
@ -112,7 +111,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
rt = ragged_squeeze_op.squeeze(
ragged_factory_ops.constant(input_list), squeeze_ranks)
dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)
self.assertAllEqual(ragged_conversion_ops.to_tensor(rt), dt)
@parameterized.parameters([
# ragged_conversion_ops.from_tensor does not work for this
@ -167,7 +166,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
rt = ragged_conversion_ops.from_tensor(dt)
rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
dt_s = array_ops.squeeze(dt, squeeze_ranks)
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt_s), dt_s)
self.assertAllEqual(ragged_conversion_ops.to_tensor(rt_s), dt_s)
@parameterized.parameters([
{
@ -185,7 +184,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
rt = ragged_factory_ops.constant(input_list)
rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
ref = ragged_factory_ops.constant(output_list)
self.assertRaggedEqual(rt_s, ref)
self.assertAllEqual(rt_s, ref)
def test_passing_text(self):
rt = ragged_factory_ops.constant([[[[[[[['H']], [['e']], [['l']], [['l']],
@ -202,7 +201,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
['M', 'e', 'h', 'r', 'd', 'a', 'd'], ['.']]]
ref = ragged_factory_ops.constant(output_list)
rt_s = ragged_squeeze_op.squeeze(rt, [0, 1, 3, 6, 7])
self.assertRaggedEqual(rt_s, ref)
self.assertAllEqual(rt_s, ref)
@parameterized.parameters([
{

View File

@ -24,12 +24,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedStackOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedStackOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters(
@ -335,7 +334,7 @@ class RaggedStackOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertEqual(stacked.ragged_rank, expected_ragged_rank)
if expected_shape is not None:
self.assertEqual(stacked.shape.as_list(), expected_shape)
self.assertRaggedEqual(stacked, expected)
self.assertAllEqual(stacked, expected)
@parameterized.parameters(
dict(
@ -372,7 +371,7 @@ class RaggedStackOpTest(ragged_test_util.RaggedTensorTestCase,
"""
rt_inputs = ragged_factory_ops.constant([[1, 2], [3, 4]])
stacked = ragged_concat_ops.stack(rt_inputs, 0)
self.assertRaggedEqual(stacked, [[[1, 2], [3, 4]]])
self.assertAllEqual(stacked, [[[1, 2], [3, 4]]])
if __name__ == '__main__':

View File

@ -21,43 +21,42 @@ from __future__ import print_function
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorBoundingShapeOp(ragged_test_util.RaggedTensorTestCase):
class RaggedTensorBoundingShapeOp(test_util.TensorFlowTestCase):
def testDocStringExample(self):
# This is the example from ragged.bounding_shape.__doc__.
rt = ragged_factory_ops.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9],
[10]])
self.assertRaggedEqual(rt.bounding_shape(), [5, 4])
self.assertAllEqual(rt.bounding_shape(), [5, 4])
def test2DRaggedTensorWithOneRaggedDimension(self):
values = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
rt1 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 2, 5, 6, 6, 7])
rt2 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 7])
rt3 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 0, 7, 7])
self.assertRaggedEqual(rt1.bounding_shape(), [5, 3])
self.assertRaggedEqual(rt2.bounding_shape(), [1, 7])
self.assertRaggedEqual(rt3.bounding_shape(), [3, 7])
self.assertAllEqual(rt1.bounding_shape(), [5, 3])
self.assertAllEqual(rt2.bounding_shape(), [1, 7])
self.assertAllEqual(rt3.bounding_shape(), [3, 7])
def test3DRaggedTensorWithOneRaggedDimension(self):
values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
rt1 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 2, 5, 6, 6, 7])
rt2 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 7])
rt3 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 0, 7, 7])
self.assertRaggedEqual(rt1.bounding_shape(), [5, 3, 2])
self.assertRaggedEqual(rt2.bounding_shape(), [1, 7, 2])
self.assertRaggedEqual(rt3.bounding_shape(), [3, 7, 2])
self.assertAllEqual(rt1.bounding_shape(), [5, 3, 2])
self.assertAllEqual(rt2.bounding_shape(), [1, 7, 2])
self.assertAllEqual(rt3.bounding_shape(), [3, 7, 2])
def testExplicitAxisOptimizations(self):
rt = ragged_tensor.RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
[0, 2, 5, 6, 6, 7])
self.assertRaggedEqual(rt.bounding_shape(0), 5)
self.assertRaggedEqual(rt.bounding_shape(1), 3)
self.assertRaggedEqual(rt.bounding_shape([1, 0]), [3, 5])
self.assertAllEqual(rt.bounding_shape(0), 5)
self.assertAllEqual(rt.bounding_shape(1), 3)
self.assertAllEqual(rt.bounding_shape([1, 0]), [3, 5])
if __name__ == '__main__':

View File

@ -26,27 +26,20 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamicShape
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def assertShapeEq(self, x, y):
assert isinstance(x, RaggedTensorDynamicShape)
assert isinstance(y, RaggedTensorDynamicShape)
x_partitioned_dim_sizes = [
self.eval_to_list(splits) #
for splits in x.partitioned_dim_sizes
]
y_partitioned_dim_sizes = [
self.eval_to_list(splits) #
for splits in y.partitioned_dim_sizes
]
self.assertEqual(x_partitioned_dim_sizes, y_partitioned_dim_sizes)
self.assertLen(x.partitioned_dim_sizes, len(y.partitioned_dim_sizes))
for x_dims, y_dims in zip(x.partitioned_dim_sizes, y.partitioned_dim_sizes):
self.assertAllEqual(x_dims, y_dims)
self.assertAllEqual(x.inner_dim_sizes, y.inner_dim_sizes)
@parameterized.parameters([
@ -422,7 +415,7 @@ class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
result = ragged_tensor_shape.broadcast_to(x, shape)
self.assertEqual(
getattr(result, 'ragged_rank', 0), getattr(expected, 'ragged_rank', 0))
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
@parameterized.parameters(
[
@ -484,7 +477,7 @@ class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
self.assertEqual(expected_rrank, result_rrank)
if hasattr(expected, 'tolist'):
expected = expected.tolist()
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
if __name__ == '__main__':

View File

@ -34,7 +34,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.platform import googletest
@ -114,7 +113,7 @@ EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
class RaggedTensorTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
@ -126,7 +125,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
# From section: "Component Tensors"
rt = RaggedTensor.from_row_splits(
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
self.assertRaggedEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
del rt
# From section: "Alternative Row-Partitioning Schemes"
@ -138,7 +137,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
for rt in (rt1, rt2, rt3, rt4, rt5):
self.assertRaggedEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
del rt1, rt2, rt3, rt4, rt5
# From section: "Multiple Ragged Dimensions"
@ -147,8 +146,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
outer_rt = RaggedTensor.from_row_splits(
values=inner_rt, row_splits=[0, 3, 3, 5])
self.assertEqual(outer_rt.ragged_rank, 2)
self.assertEqual(
self.eval_to_list(outer_rt),
self.assertAllEqual(
outer_rt,
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del inner_rt, outer_rt
@ -156,15 +155,15 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt = RaggedTensor.from_nested_row_splits(
flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
self.assertEqual(
self.eval_to_list(rt), [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
self.assertAllEqual(
rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del rt
# From section: "Uniform Inner Dimensions"
rt = RaggedTensor.from_row_splits(
values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
self.assertEqual(rt.shape.as_list(), [2, None, 3])
del rt
@ -211,8 +210,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rt = RaggedTensor(values=values, row_splits=row_splits, internal=True)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testRaggedTensorConstructionErrors(self):
@ -267,9 +266,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertEqual(self.eval_to_list(rt_nrows), 5)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
@ -293,9 +292,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertEqual(self.eval_to_list(rt_nrows), 5)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithExplicitNRows(self):
@ -316,8 +315,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertIs(rt_nrows, nrows) # cached_nrows
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
@ -340,8 +339,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_nrows, nrows) # cached_nrows
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, nrows)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithEmptyValues(self):
@ -352,8 +351,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertEqual(rt.ragged_rank, 1)
self.assertEqual(rt.values.shape.as_list(), [0])
self.assertEqual(rt.value_rowids().shape.as_list(), [0])
self.assertEqual(self.eval_to_list(rt_nrows), 0)
self.assertEqual(self.eval_to_list(rt), [])
self.assertAllEqual(rt_nrows, 0)
self.assertAllEqual(rt, [])
def testFromRowSplits(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -370,9 +369,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values, values)
self.assertIs(rt_row_splits, row_splits)
self.assertEqual(self.eval_to_list(rt_nrows), 5)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowSplitsWithEmptySplits(self):
@ -394,10 +393,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertEqual(self.eval_to_list(rt_nrows), 5)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_starts, row_starts)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLimits(self):
@ -414,10 +413,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertEqual(self.eval_to_list(rt_nrows), 5)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_limits, row_limits)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLengths(self):
@ -435,10 +434,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values, values)
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
self.assertEqual(self.eval_to_list(rt_nrows), 5)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_lengths, row_lengths)
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromNestedValueRowIdsWithDerivedNRows(self):
@ -461,8 +460,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values_values, values)
self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedValueRowIdsWithExplicitNRows(self):
@ -494,9 +493,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
self.assertAllEqual(rt_nrows, nrows[0])
self.assertAllEqual(rt_values_nrows, nrows[1])
self.assertEqual(
self.eval_to_list(rt), [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
[[b'f'], [b'g'], []], [], []])
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
[[b'f'], [b'g'], []], [], []])
def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -541,8 +540,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
self.assertIs(rt_values_values, flat_values)
self.assertIs(rt_row_splits, nested_row_splits[0])
self.assertIs(rt_values_row_splits, nested_row_splits[1])
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedRowSplitsWithNonListInput(self):
@ -609,7 +608,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
for rt in [rt1, rt2]:
self.assertRaggedEqual(
self.assertAllEqual(
rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt.values, [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertEqual(rt.values.shape.dims[0].value, 7)
@ -632,26 +631,26 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
for rt in [rt1, rt2]:
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], [[10, 11]],
[[12, 13]]])
self.assertEqual(
self.eval_to_list(rt.values),
self.assertAllEqual(
rt.values,
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
self.assertEqual(rt.values.shape.dims[0].value, 7)
self.assertEqual(
self.eval_to_list(rt.value_rowids()), [0, 0, 2, 2, 2, 3, 4])
self.assertEqual(self.eval_to_list(rt.nrows()), 5)
self.assertEqual(self.eval_to_list(rt.row_splits), [0, 2, 2, 5, 6, 7])
self.assertEqual(self.eval_to_list(rt.row_starts()), [0, 2, 2, 5, 6])
self.assertEqual(self.eval_to_list(rt.row_limits()), [2, 2, 5, 6, 7])
self.assertEqual(self.eval_to_list(rt.row_lengths()), [2, 0, 3, 1, 1])
self.assertEqual(
self.eval_to_list(rt.flat_values),
self.assertAllEqual(
rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
self.assertAllEqual(rt.nrows(), 5)
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
self.assertAllEqual(
rt.flat_values,
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
self.assertEqual([self.eval_to_list(s) for s in rt.nested_row_splits],
[[0, 2, 2, 5, 6, 7]])
self.assertLen(rt.nested_row_splits, 1)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -667,24 +666,25 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt2 = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
for rt in [rt1, rt2]:
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
self.assertEqual(
self.eval_to_list(rt.values),
self.assertAllEqual(
rt.values,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertEqual(rt.values.shape.dims[0].value, 5)
self.assertEqual(self.eval_to_list(rt.value_rowids()), [0, 0, 1, 3, 3])
self.assertEqual(self.eval_to_list(rt.nrows()), 4)
self.assertEqual(self.eval_to_list(rt.row_splits), [0, 2, 3, 3, 5])
self.assertEqual(self.eval_to_list(rt.row_starts()), [0, 2, 3, 3])
self.assertEqual(self.eval_to_list(rt.row_limits()), [2, 3, 3, 5])
self.assertEqual(self.eval_to_list(rt.row_lengths()), [2, 1, 0, 2])
self.assertEqual(
self.eval_to_list(rt.flat_values),
self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
self.assertAllEqual(rt.nrows(), 4)
self.assertAllEqual(rt.row_splits, [0, 2, 3, 3, 5])
self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
self.assertAllEqual(
rt.flat_values,
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertEqual([self.eval_to_list(s) for s in rt.nested_row_splits],
[[0, 2, 3, 3, 5], [0, 2, 2, 5, 6, 7]])
self.assertLen(rt.nested_row_splits, 2)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
#=============================================================================
# RaggedTensor.shape
@ -742,12 +742,12 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
"""
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
value1 = self.eval_to_list(rt.__getitem__(slice_spec))
value2 = self.eval_to_list(rt.__getitem__(tensor_slice_spec1))
value3 = self.eval_to_list(rt.__getitem__(tensor_slice_spec2))
self.assertEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
self.assertEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
self.assertEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
value1 = rt.__getitem__(slice_spec)
value2 = rt.__getitem__(tensor_slice_spec1)
value3 = rt.__getitem__(tensor_slice_spec2)
self.assertAllEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
self.assertAllEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
self.assertAllEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
def _TestGetItemException(self, rt, slice_spec, expected, message):
"""Helper function for testing RaggedTensor.__getitem__ exceptions."""
@ -826,7 +826,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_2D)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
self._TestGetItem(rt, slice_spec, expected)
# pylint: disable=invalid-slice-index
@ -872,7 +872,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_2D)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
self._TestGetItemException(rt, slice_spec, expected, message)
@parameterized.parameters(
@ -946,7 +946,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt = RaggedTensor.from_nested_row_splits(
EXAMPLE_RAGGED_TENSOR_4D_VALUES,
[EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_4D)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_4D)
self._TestGetItem(rt, slice_spec, expected)
@parameterized.parameters(
@ -973,7 +973,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt = RaggedTensor.from_nested_row_splits(
EXAMPLE_RAGGED_TENSOR_4D_VALUES,
[EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_4D)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_4D)
self._TestGetItemException(rt, slice_spec, expected, message)
@parameterized.parameters(
@ -1015,7 +1015,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
EXAMPLE_RAGGED_TENSOR_2D_SPLITS, dtype=dtypes.int64)
splits = array_ops.placeholder_with_default(splits, None)
rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES, splits)
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_2D)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
self._TestGetItem(rt, slice_spec, expected)
@parameterized.parameters(
@ -1042,23 +1042,23 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt_newaxis3 = rt[:, :, :, array_ops.newaxis]
rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis]
self.assertEqual(
self.eval_to_list(rt),
self.assertAllEqual(
rt,
[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
self.assertEqual(
self.eval_to_list(rt_newaxis0),
self.assertAllEqual(
rt_newaxis0,
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
self.assertEqual(
self.eval_to_list(rt_newaxis1),
self.assertAllEqual(
rt_newaxis1,
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]])
self.assertEqual(
self.eval_to_list(rt_newaxis2),
self.assertAllEqual(
rt_newaxis2,
[[[[[b'a', b'b'], [b'c', b'd']]], [[]], [[[b'e', b'f']]]], []])
self.assertEqual(
self.eval_to_list(rt_newaxis3),
self.assertAllEqual(
rt_newaxis3,
[[[[[b'a', b'b']], [[b'c', b'd']]], [], [[[b'e', b'f']]]], []])
self.assertEqual(
self.eval_to_list(rt_newaxis4),
self.assertAllEqual(
rt_newaxis4,
[[[[[b'a'], [b'b']], [[b'c'], [b'd']]], [], [[[b'e'], [b'f']]]], []])
self.assertEqual(rt.ragged_rank, 2)
@ -1121,14 +1121,14 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10)
rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
self.assertEqual(
self.eval_to_list(rt1_plus_10),
self.assertAllEqual(
rt1_plus_10,
[[11, 12], [13, 14, 15], [16], [], [17]])
self.assertEqual(
self.eval_to_list(rt2_times_10),
self.assertAllEqual(
rt2_times_10,
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
self.assertEqual(
self.eval_to_list(rt1_expanded),
self.assertAllEqual(
rt1_expanded,
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
#=============================================================================
@ -1494,7 +1494,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
et = rt._to_variant()
round_trip_rt = RaggedTensor._from_variant(
et, dtype, output_ragged_rank=ragged_rank)
self.assertRaggedEqual(rt, round_trip_rt)
self.assertAllEqual(rt, round_trip_rt)
def testBatchedVariantRoundTripInputRaggedRankInferred(self):
ragged_rank = 1
@ -1510,7 +1510,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
[5]],
[[6], [7]], [[8], [9]]])
self.assertRaggedEqual(decoded_rt, expected_rt)
self.assertAllEqual(decoded_rt, expected_rt)
def testBatchedVariantRoundTripWithInputRaggedRank(self):
ragged_rank = 1
@ -1527,7 +1527,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
[5]],
[[6], [7]], [[8], [9]]])
self.assertRaggedEqual(decoded_rt, expected_rt)
self.assertAllEqual(decoded_rt, expected_rt)
def testFromVariantInvalidParams(self):
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])

View File

@ -1,106 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Test utils for tensorflow RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
class RaggedTensorTestCase(test_util.TensorFlowTestCase):
"""Base class for RaggedTensor test cases."""
def _GetPyList(self, a):
"""Converts a to a nested python list."""
if isinstance(a, ragged_tensor.RaggedTensor):
return self.evaluate(a).to_list()
elif isinstance(a, ops.Tensor):
a = self.evaluate(a)
return a.tolist() if isinstance(a, np.ndarray) else a
elif isinstance(a, np.ndarray):
return a.tolist()
elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
return a.to_list()
else:
return np.array(a).tolist()
def assertRaggedEqual(self, a, b):
"""Asserts that two potentially ragged tensors are equal."""
a_list = self._GetPyList(a)
b_list = self._GetPyList(b)
self.assertEqual(a_list, b_list)
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
self.assertEqual(a_ragged_rank, b_ragged_rank)
def assertRaggedAlmostEqual(self, a, b, places=7):
a_list = self._GetPyList(a)
b_list = self._GetPyList(b)
self.assertNestedListAlmostEqual(a_list, b_list, places, context='value')
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
self.assertEqual(a_ragged_rank, b_ragged_rank)
def assertNestedListAlmostEqual(self, a, b, places=7, context='value'):
self.assertEqual(type(a), type(b))
if isinstance(a, (list, tuple)):
self.assertLen(a, len(b), 'Length differs for %s' % context)
for i in range(len(a)):
self.assertNestedListAlmostEqual(a[i], b[i], places,
'%s[%s]' % (context, i))
else:
self.assertAlmostEqual(
a, b, places,
'%s != %s within %s places at %s' % (a, b, places, context))
def eval_to_list(self, tensor):
value = self.evaluate(tensor)
if ragged_tensor.is_ragged(value):
return value.to_list()
elif isinstance(value, np.ndarray):
return value.tolist()
else:
return value
def _eval_tensor(self, tensor):
if ragged_tensor.is_ragged(tensor):
return ragged_tensor_value.RaggedTensorValue(
self._eval_tensor(tensor.values),
self._eval_tensor(tensor.row_splits))
else:
return test_util.TensorFlowTestCase._eval_tensor(self, tensor)
@staticmethod
def _normalize_pylist(item):
"""Convert all (possibly nested) np.arrays contained in item to list."""
# convert np.arrays in current level to list
if np.ndim(item) == 0:
return item
level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
_normalize = RaggedTensorTestCase._normalize_pylist
return [_normalize(el) if np.ndim(el) != 0 else el for el in level]

View File

@ -26,12 +26,11 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTileOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedTileOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -209,7 +208,7 @@ class RaggedTileOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertEqual(tiled.shape.ndims, rt.shape.ndims)
if multiples_tensor is const_multiples:
self.assertEqual(tiled.shape.as_list(), expected_shape)
self.assertRaggedEqual(tiled, expected)
self.assertAllEqual(tiled, expected)
def testRaggedTileWithTensorInput(self):
# When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
@ -218,7 +217,7 @@ class RaggedTileOpTest(ragged_test_util.RaggedTensorTestCase,
expected = [[1, 2, 1, 2], [3, 4, 3, 4],
[1, 2, 1, 2], [3, 4, 3, 4],
[1, 2, 1, 2], [3, 4, 3, 4]] # pyformat: disable
self.assertRaggedEqual(tiled, expected)
self.assertAllEqual(tiled, expected)
if __name__ == '__main__':

View File

@ -28,12 +28,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase):
def testDocStringExample(self):
rt = ragged_factory_ops.constant([[1, 2, 3], [4], [], [5, 6]])
@ -193,8 +192,8 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
g1, g2 = gradients_impl.gradients(st.values,
[rt1.flat_values, rt2.flat_values])
self.assertRaggedEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
self.assertRaggedEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
self.assertAllEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
self.assertAllEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
if __name__ == '__main__':

View File

@ -19,17 +19,16 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorToTensorOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def testDocStringExamples(self):
@ -105,10 +104,9 @@ class RaggedTensorToTensorOpTest(ragged_test_util.RaggedTensorTestCase,
self.assertIsInstance(dt, ops.Tensor)
self.assertEqual(rt.dtype, dt.dtype)
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
self.assertAllEqual(self.eval_to_list(dt), expected)
if expected_shape is not None:
dt_shape = array_ops.shape(dt)
self.assertAllEqual(dt_shape, expected_shape)
expected = np.ndarray(expected_shape, buffer=np.array(expected))
self.assertAllEqual(dt, expected)
@parameterized.parameters(
{

View File

@ -24,7 +24,6 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.platform import googletest
@ -43,7 +42,7 @@ TENSOR_4D = [[[[('%d%d%d%d' % (i, j, k, l)).encode('utf-8')
@test_util.run_all_in_graph_and_eager_modes
class RaggedUtilTest(ragged_test_util.RaggedTensorTestCase,
class RaggedUtilTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([

View File

@ -20,13 +20,12 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase,
class RaggedWhereOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
@ -188,7 +187,7 @@ class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase,
]) # pyformat: disable
def testRaggedWhere(self, condition, expected, x=None, y=None):
result = ragged_where_op.where(condition, x, y)
self.assertRaggedEqual(result, expected)
self.assertAllEqual(result, expected)
@parameterized.parameters([
dict(

View File

@ -91,7 +91,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/kernel_tests/random:util",
"//tensorflow/python/kernel_tests/signal:test_util",
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
"//tensorflow/python/ops/ragged:ragged_test_util",
"//tensorflow/python/saved_model:saved_model",
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",