Merged commit includes the following changes:

237836124  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal changes

--
237831779  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix rnn.reset_states() to throw proper error when before .build() is invoked.

    See https://github.com/tensorflow/tensorflow/issues/25852 for more details.

--
237814078  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix GatherV2 converter. It assumed that the IGatherLayer has same output dimensions as tf.gather, which is not the case.

--
237814023  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Split py_binary into py_binary and py_library to avoid having py_binary in deps.

--

PiperOrigin-RevId: 237836124
This commit is contained in:
A. Unique TensorFlower 2019-03-11 10:54:44 -07:00 committed by TensorFlower Gardener
parent 2b9900638a
commit 6108cbd1db
7 changed files with 314 additions and 246 deletions

View File

@ -2243,8 +2243,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
// Make sure target dimension is size 1.
if (input_dims[axis] != 1) {
return errors::InvalidArgument(
"Cannot squeeze a dimension which isn't size 1, at ",
node_def.name());
"Cannot squeeze ", axis, "th dimension ", input_dims[axis],
" which isn't size 1, at ", node_def.name());
}
// Mark dim for removal by setting to 0.
input_dims[axis] = 0;
@ -3700,13 +3700,59 @@ Status ConvertGather(OpConverterParams* params) {
int trt_axis = 0;
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
node_def.name(), &trt_axis));
TRT_TensorOrWeights params_tensor = inputs.at(0);
TRT_TensorOrWeights indices_tensor = inputs.at(1);
if (indices_tensor.batch_size() != 1) {
return errors::InvalidArgument("Only indices with batch 1 are supported.");
}
// Both input are tensors, and the TF gather result will have rank:
// (params.nbDims + 1) + (indices.nbDims + 1) - 1,
// where "+ 1" adds the batch dim.
const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims +
indices_tensor.GetTrtDims().nbDims + 1;
if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
return errors::InvalidArgument(
"Result of gather has dimension greater than ",
nvinfer1::Dims::MAX_DIMS + 1);
}
if (params->validation_only) return Status::OK();
// Note on how IGatherLayer works: if both the data and indices tensors have
// a batch size dimension of size N, it performs:
// for batchid in xrange(N):
// output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
// data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
*const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
*const_cast<nvinfer1::ITensor*>(inputs.at(1).tensor()), trt_axis);
*const_cast<nvinfer1::ITensor*>(params_tensor.tensor()),
*const_cast<nvinfer1::ITensor*>(indices_tensor.tensor()), trt_axis);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
nvinfer1::ITensor* gather_output = layer->getOutput(0);
nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions();
// Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
// and the other is for the output dimension that is squeezed by IGatherLayer
// because of the implicit batch dim in the indices (see the above note).
if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) {
return errors::Internal(
"Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
tf_gather_output_rank - 2,
", actual nbDims: ", trt_gather_output_dims.nbDims);
}
// Reshape the output so after adding the implicit batch dim it'll match the
// output shape of TF GatherV2.
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
}
trt_gather_output_dims.d[trt_axis] = 1;
++trt_gather_output_dims.nbDims;
const nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
TRT_TensorOrWeights(gather_output), trt_gather_output_dims,
/*validation_only=*/false, &output_tensor));
params->outputs->push_back(
TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
return Status::OK();
}

View File

@ -3400,14 +3400,23 @@ void TestConvertGather(OpConverterTest* test) {
// Input is the same {1, 2, 3, 4, 5, 6} for all cases.
const int kGatherOKCases = 5;
const std::vector<CType> params_input = {CType(1), CType(2), CType(3),
CType(4), CType(5), CType(6)};
TestParams ok_params[kGatherOKCases] = {
// Vector indices (output is rank(params)).
TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}},
TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}},
TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}},
TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}},
// Higher rank indices (output is rank(params) + rank(indices) - 1).
TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}},
// Indices are always of rank>1, and output rank is
// rank(params) + rank(indices) - 1.
// TODO(laigd): do we support 0-rank ITensor as indices?
TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}},
TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}},
TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}},
TestParams{
{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}},
TestParams{{3, 2},
{2, 2},
{0, 0, 1, 0},
2,
{3, 1, 2, 2},
{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}},
};
// Ok.
@ -3426,14 +3435,12 @@ void TestConvertGather(OpConverterTest* test) {
output.tensor()->getDimensions());
// Create input in CType and convert expected output to CType.
std::vector<CType> inputs = {CType(1), CType(2), CType(3),
CType(4), CType(5), CType(6)};
std::vector<CType> converted_expected_output(
ok_params[i].expected_output.begin(),
ok_params[i].expected_output.end());
const DataVec input_data{
{"params", test::AsTensor<CType>(inputs)},
{"params", test::AsTensor<CType>(params_input)},
{"indices", test::AsTensor<int32>(ok_params[i].indices)}};
DataVec output_data{
{"my_gather",

View File

@ -72,7 +72,7 @@ py_test(
"//tensorflow/lite/python:lite",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python/tools:optimize_for_inference",
"//tensorflow/python/tools:optimize_for_inference_main_lib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -95,7 +95,7 @@ py_test(
"//tensorflow/lite/python:lite",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python/tools:optimize_for_inference",
"//tensorflow/python/tools:optimize_for_inference_main_lib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -118,7 +118,7 @@ py_test(
"//tensorflow/lite/python:lite",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python/tools:optimize_for_inference",
"//tensorflow/python/tools:optimize_for_inference_main_lib",
"//third_party/py/numpy",
"@six_archive//:six",
],

View File

@ -1,245 +1,245 @@
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/mpi/BUILD
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/gast.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/swig.BUILD
tensorflow/third_party/astor.BUILD
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/astor.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/BUILD
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/swig.BUILD
tensorflow/compat_template.__init__.py
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/README.md
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/api_template.__init__.py
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/mpi/BUILD
tensorflow/__init__.py
tensorflow/stream_executor/build_defs.bzl
tensorflow/api_template_v1.__init__.py
tensorflow/compat_template_v1.__init__.py
tensorflow/compat_template.__init__.py
tensorflow/api_template.__init__.py
tensorflow/__init__.py
tensorflow/compat_template_v1.__init__.py

View File

@ -47,7 +47,7 @@ py_library(
":cli_test_utils",
":debug_py",
":grpc_debug_test_server",
":grpc_tensorflow_server",
":grpc_tensorflow_server_lib",
":offline_analyzer_lib",
":session_debug_testlib",
":source_remote",
@ -1124,7 +1124,7 @@ cuda_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
data = [":grpc_tensorflow_server"],
data = [":grpc_tensorflow_server_lib"],
grpc_enabled = True,
tags = [
"no_oss", # Incompatible with bazel_pip.

View File

@ -833,10 +833,14 @@ class RNN(Layer):
def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
if self.time_major:
batch_size = self.input_spec[0].shape[1]
spec_shape = None if self.input_spec is None else self.input_spec[0].shape
if spec_shape is None:
# It is possible to have spec shape to be None, eg when construct a RNN
# with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
# it has 3 dim input, but not its full shape spec before build().
batch_size = None
else:
batch_size = self.input_spec[0].shape[0]
batch_size = spec_shape[1] if self.time_major else spec_shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '

View File

@ -1315,6 +1315,17 @@ class RNNTest(keras_parameterized.TestCase):
model = keras.Model([inputs, state_h, state_c], decoder_out)
model.reset_states()
def test_reset_states(self):
# See https://github.com/tensorflow/tensorflow/issues/25852
with self.assertRaisesRegexp(ValueError, 'it needs to know its batch size'):
simple_rnn = keras.layers.SimpleRNN(1, stateful=True)
simple_rnn.reset_states()
with self.assertRaisesRegexp(ValueError, 'it needs to know its batch size'):
cell = Minimal2DRNNCell(1, 2)
custom_rnn = keras.layers.RNN(cell, stateful=True)
custom_rnn.reset_states()
class Minimal2DRNNCell(keras.layers.Layer):
"""The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell.